{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neural Ordinary Differential Equations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Значительная доля процессов описывается дифференциальными уравнениями, это могут быть эволюция физической системы во времени, медицинское состояние пациента, фундаментальные характеристики фондового рынка и т.д. Данные о таких процессах последовательны и непрерывны по своей природе, в том смысле, что наблюдения - это просто проявления какого-то непрерывно изменяющегося состояния.\n",
    "\n",
    "Есть также и другой тип последовательных данных, это дискретные данные, например, данные NLP задач. Состояния в таких данных меняется дискретно: от одного символа или слова к другому.\n",
    "\n",
    "Сейчас оба типа таких последовательных данных обычно обрабатываются рекуррентными сетями, несмотря на то, что они отличны по своей природе, и похоже, требуют различных подходов.\n",
    "\n",
    "На последней *NIPS-конференции* была представлена одна очень интересная статья, которая может помочь решить эту проблему. Авторы предлагают очень интересный подход, который они назвали **Нейронные Обыкновенные Дифференциальные Уравнения (Neural ODE)**.\n",
    "\n",
    "Здесь я постарался воспроизвести и кратко изложить результаты этой статьи, чтобы сделать знакомство с *ее идеей чуть* более простым. Мне кажется, что эта новая архитектура вполне может найти место в стандартном инструментарии дата-сайентиста наряду со *сверхточными* и рекуррентными сетями."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Постановка проблемы\n",
    "\n",
    "Пусть есть процесс, который подчиняется некоторому неизвестному ОДУ и пусть есть несколько (зашумленных) наблюдений вдоль траектории процесса\n",
    "\n",
    "\n",
    "$$\n",
    "\\frac{dz}{dt} = f(z(t), t) \\tag{1}\n",
    "$$\n",
    "$$\n",
    "\\{(z_0, t_0),(z_1, t_1),...,(z_M, t_M)\\} - \\text{наблюдения}\n",
    "$$\n",
    "\n",
    "Как найти аппроксимацию $\\widehat{f}(z, t, \\theta)$ функции динамики $f(z, t)$?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Сначала рассмотрим более простую задачу: есть только 2 наблюдения, в начале и в конце траектории, $(z_0, t_0), (z_1, t_1)$.\n",
    "Эволюция системы запускается из состояния $z_0, t_0$ на время $t_1 - t_0$ с какой-то параметризованной функцией динамики, используя любой метод эволюции систем ОДУ. После того, как система оказывается в новом состоянии $\\hat{z_1}, t_1$, оно сравнивается с состоянием $z_1$ и разница между ними минимизируется варьированием параметров $\\theta$ функции динамики.\n",
    "\n",
    "Или, более формально, рассмотрим минимизацию функции потерь $L(\\hat{z_1})$:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\n",
    "L(z(t_1)) = L \\Big( \\int_{t_0}^{t_1} f(z(t), t, \\theta)dt \\Big) = L \\big( \\text{ODESolve}(z(t_0), f, t_0, t_1, \\theta) \\big) \\tag{2}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=assets/backprop.png width=600></img>\n",
    "<p style=\"text-align: center\">Картинка 1: Непрерывный *backpropagation* градиента требует решения аугментированного дифференциального уравнения назад во времени. <br /> Стрелки представляют корректировку распространенных назад градиентов градиентами от наблюдений.<br />\n",
    "Иллюстрация из оригинальной статьи</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "В случае, если вам лень лезть в математику, картинка выше дает довольно хорошее представление о том, что происходит.\n",
    "Черная траектория олицетворяет решение ОДУ во время прохода вперед, а красная представляет решение сопряженного ОДУ во время бэкпропагейшена. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Чтобы минимизировать $L$, нужно рассчитать градиенты по всем его параметрами: $z(t_0), t_0, t_1, \\theta$. Чтобы сделать это, сначала нужно определить, как $L$ зависит от состояния в каждый момент времени $(z(t))$: \n",
    "$$\n",
    "a(t) = -\\frac{\\partial L}{\\partial z(t)} \\tag{3}\n",
    "$$\n",
    "$a(t)$ зовется *сопряженным* (*adjoint*) состоянием, его динамика задается другим дифференциальными уравнением, которое можно считать непрерывным аналогом дифференцирования сложной функции (*chain rule*):"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\n",
    "\\frac{d a(t)}{d t} = -a(t) \\frac{\\partial f(z(t), t, \\theta)}{\\partial z} \\tag{4}\n",
    "$$\n",
    "\n",
    "Вывод этой формулы можно посмотреть в аппендиксе оригинальной статьи."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Векторы в этой статье следует считать строчными векторами, хотя оригинальная статья использует и строчное и столбцовое представление."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Решая диффур (4) назад во времени, получаем зависимость от начального состояния $z(t_0)$:\n",
    "\n",
    "$$\n",
    "\\frac{\\partial L}{\\partial z(t_0)} = \\int_{t_1}^{t_0} a(t) \\frac{\\partial f(z(t), t, \\theta)}{\\partial z} dt \\tag{5}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Чтобы рассчитать градиент по отношению к $t$ and $\\theta$, можно просто считать их частью состояния. Такое состояние зовется *аугментированным*. Динамика такого состояния тривиально получается из оригинальной динамики:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\n",
    "\\frac{d}{dt} \\begin{bmatrix} z \\\\ \\theta \\\\ t \\end{bmatrix} (t) = f_{\\text{aug}}([z, \\theta, t]) := \\begin{bmatrix} f([z, \\theta, t ]) \\\\ 0 \\\\ 1 \\end{bmatrix} \\tag{6}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Тогда сопряженное состояние к этому аугментированному состоянию:\n",
    "\n",
    "$$\n",
    "a_{\\text{aug}} := \\begin{bmatrix} a \\\\ a_{\\theta} \\\\ a_t \\end{bmatrix}, a_{\\theta}(t) := \\frac{\\partial L}{\\partial \\theta(t)}, a_t(t) := \\frac{\\partial L}{\\partial t(t)} \\tag{7}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Градиент аугментированной динамики:\n",
    "\n",
    "$$\n",
    "\\frac{\\partial f_{\\text{aug}}}{\\partial [z, \\theta, t]} = \\begin{bmatrix} \n",
    "\\frac{\\partial f}{\\partial z} & \\frac{\\partial f}{\\partial \\theta} & \\frac{\\partial f}{\\partial t} \\\\\n",
    "0 & 0 & 0 \\\\\n",
    "0 & 0 & 0\n",
    "\\end{bmatrix} \\tag{8}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Дифференциальное уравнение сопряженного аугментированного состояния из формулы (4) тогда:\n",
    "\n",
    "$$\n",
    "\\frac{d a_{\\text{aug}}}{dt} = - \\begin{bmatrix} a\\frac{\\partial f}{\\partial z} & a\\frac{\\partial f}{\\partial \\theta} & a\\frac{\\partial f}{\\partial t}\\end{bmatrix} \\tag{9}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Решение этого ОДУ назад во времени дает:\n",
    "\n",
    "$$\n",
    "\\frac{\\partial L}{\\partial z(t_0)} = \\int_{t_1}^{t_0} a(t) \\frac{\\partial f(z(t), t, \\theta)}{\\partial z} dt \\tag{10}\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\frac{\\partial L}{\\partial \\theta} = \\int_{t_1}^{t_0} a(t) \\frac{\\partial f(z(t), t, \\theta)}{\\partial \\theta} dt \\tag{11}\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\frac{\\partial L}{\\partial t_0} = \\int_{t_1}^{t_0} a(t) \\frac{\\partial f(z(t), t, \\theta)}{\\partial t} dt \\tag{12}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Что вместе с\n",
    "\n",
    "$$\n",
    "\\frac{\\partial L}{\\partial t_1} = - a(t) \\frac{\\partial f(z(t), t, \\theta)}{\\partial t} \\tag{13}\n",
    "$$\n",
    "\n",
    "дает градиенты по всем входным параметрам в решатель ОДУ *ODESolve*."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Все градиенты (10), (11), (12), (13) могут быть рассчитаны вместе за один вызов *ODESolve* с динамикой сопряженного аугментированного состояния (9). "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=assets/pseudocode.png width=800></img>\n",
    "<div align=\"center\">Иллюстрация из оригинальной статьи</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Алгоритм выше описывает обратное распространения градиента решения ОДУ для последовательных наблюдений. Этот алгоритм реализует в себе все описанное выше."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "В случае нескольких наблюдений на одну траекторию все рассчитывается так же, но в моменты наблюдений обратно распространенный градиент надо корректировать градиентами от текущего наблюдения, как показано в *иллюстрации 1*."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Реализация "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Код ниже - это моя реализация **Нейронных ОДУ**. Я делал это сугубо для лучшего понимания того, что происходит. Впрочем, оно очень близка к тому, что реализовано в [репозитории](https://github.com/rtqichen/torchdiffeq) у авторов статьи. Здесь содержится весь нужный для понимания код в одном месте, он также слегка более закомментированный. Для реального применения и экспериментов все же лучше использовать реализацию авторов оригинальной статьи."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "from IPython.display import clear_output\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import seaborn as sns\n",
    "sns.color_palette(\"bright\")\n",
    "import matplotlib as mpl\n",
    "import matplotlib.cm as cm\n",
    "\n",
    "import torch\n",
    "from torch import Tensor\n",
    "from torch import nn\n",
    "from torch.nn  import functional as F \n",
    "from torch.autograd import Variable\n",
    "\n",
    "use_cuda = torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Для начала надо реализовать любой метод эволюции систем ОДУ. В целях простоты здесь реализован метод Эйлера, хотя подойдет любой явный или неявный метод."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ode_solve(z0, t0, t1, f):\n",
    "    \"\"\"\n",
    "    Простейший метод эволюции ОДУ - метод Эйлера\n",
    "    \"\"\"\n",
    "    h_max = 0.05\n",
    "    n_steps = math.ceil((abs(t1 - t0)/h_max).max().item())\n",
    "\n",
    "    h = (t1 - t0)/n_steps\n",
    "    t = t0\n",
    "    z = z0\n",
    "\n",
    "    for i_step in range(n_steps):\n",
    "        z = z + h * f(z, t)\n",
    "        t = t + h\n",
    "    return z"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Здесь также описан суперкласс параметризованной функции динамики с парочкой полезных методов.\n",
    "\n",
    "Во-первых: нужно возвращать все параметры от которых зависит функция в виде вектора. \n",
    "\n",
    "Во-вторых: надо рассчитывать аугментированную динамику. Эта динамика зависит от градиента параметризованной функции по параметрам и входным данным. Чтобы не приходилось каждый раз для каждой новой архитектуры прописывать градиент руками, воспользуемся методом **torch.autograd.grad**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ODEF(nn.Module):\n",
    "    def forward_with_grad(self, z, t, grad_outputs):\n",
    "        \"\"\"Compute f and a df/dz, a df/dp, a df/dt\"\"\"\n",
    "        batch_size = z.shape[0]\n",
    "\n",
    "        out = self.forward(z, t)\n",
    "\n",
    "        a = grad_outputs\n",
    "        adfdz, adfdt, *adfdp = torch.autograd.grad(\n",
    "            (out,), (z, t) + tuple(self.parameters()), grad_outputs=(a),\n",
    "            allow_unused=True, retain_graph=True\n",
    "        )\n",
    "        # метод grad автоматически суммирует градие*н*ты для всех элементов батча, надо expand их обратно \n",
    "        if adfdp is not None:\n",
    "            adfdp = torch.cat([p_grad.flatten() for p_grad in adfdp]).unsqueeze(0)\n",
    "            adfdp = adfdp.expand(batch_size, -1) / batch_size\n",
    "        if adfdt is not None:\n",
    "            adfdt = adfdt.expand(batch_size, 1) / batch_size\n",
    "        return out, adfdz, adfdt, adfdp\n",
    "\n",
    "    def flatten_parameters(self):\n",
    "        p_shapes = []\n",
    "        flat_parameters = []\n",
    "        for p in self.parameters():\n",
    "            p_shapes.append(p.size())\n",
    "            flat_parameters.append(p.flatten())\n",
    "        return torch.cat(flat_parameters)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Код ниже описывает прямое и обратное распространение для *Нейронных ОДУ*. Приходится отделить этот код от основного ***torch.nn.Module*** в виде функции ***torch.autograd.Function*** потому, что в последнем можно реализовать произвольный метод обратного распространения, в отличие от модуля. Так что это просто костыль.\n",
    "\n",
    "Эта функция лежит в основе всего подхода *Нейронных ОДУ*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ODEAdjoint(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, z0, t, flat_parameters, func):\n",
    "        assert isinstance(func, ODEF)\n",
    "        bs, *z_shape = z0.size()\n",
    "        time_len = t.size(0)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            z = torch.zeros(time_len, bs, *z_shape).to(z0)\n",
    "            z[0] = z0\n",
    "            for i_t in range(time_len - 1):\n",
    "                z0 = ode_solve(z0, t[i_t], t[i_t+1], func)\n",
    "                z[i_t+1] = z0\n",
    "\n",
    "        ctx.func = func\n",
    "        ctx.save_for_backward(t, z.clone(), flat_parameters)\n",
    "        return z\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, dLdz):\n",
    "        \"\"\"\n",
    "        dLdz shape: time_len, batch_size, *z_shape\n",
    "        \"\"\"\n",
    "        func = ctx.func\n",
    "        t, z, flat_parameters = ctx.saved_tensors\n",
    "        time_len, bs, *z_shape = z.size()\n",
    "        n_dim = np.prod(z_shape)\n",
    "        n_params = flat_parameters.size(0)\n",
    "\n",
    "        # Динамика аугментированной системы, которую надо эволюционировать обратно во времени\n",
    "        def augmented_dynamics(aug_z_i, t_i):\n",
    "            \"\"\"\n",
    "            Тензоры здесь - это срезы по времени\n",
    "            t_i - тензор с размерами: bs, 1\n",
    "            aug_z_i - тензор с размерами: bs, n_dim*2 + n_params + 1\n",
    "            \"\"\"\n",
    "            z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2*n_dim]  # игнорируем параметры и время\n",
    "\n",
    "            # Unflatten z and a\n",
    "            z_i = z_i.view(bs, *z_shape)\n",
    "            a = a.view(bs, *z_shape)\n",
    "            with torch.set_grad_enabled(True):\n",
    "                t_i = t_i.detach().requires_grad_(True)\n",
    "                z_i = z_i.detach().requires_grad_(True)\n",
    "                func_eval, adfdz, adfdt, adfdp = func.forward_with_grad(z_i, t_i, grad_outputs=a)  # bs, *z_shape\n",
    "                adfdz = adfdz.to(z_i) if adfdz is not None else torch.zeros(bs, *z_shape).to(z_i)\n",
    "                adfdp = adfdp.to(z_i) if adfdp is not None else torch.zeros(bs, n_params).to(z_i)\n",
    "                adfdt = adfdt.to(z_i) if adfdt is not None else torch.zeros(bs, 1).to(z_i)\n",
    "\n",
    "            # Flatten f and adfdz\n",
    "            func_eval = func_eval.view(bs, n_dim)\n",
    "            adfdz = adfdz.view(bs, n_dim) \n",
    "            return torch.cat((func_eval, -adfdz, -adfdp, -adfdt), dim=1)\n",
    "\n",
    "        dLdz = dLdz.view(time_len, bs, n_dim)  # flatten dLdz для удобства\n",
    "        with torch.no_grad():\n",
    "            ## Создадим плейсхолдеры для возвращаемых градиентов\n",
    "            # Распространенные назад сопряженные состояния, которые надо поправить градиентами от наблюдений\n",
    "            adj_z = torch.zeros(bs, n_dim).to(dLdz)\n",
    "            adj_p = torch.zeros(bs, n_params).to(dLdz)\n",
    "            # В отличие от z и p, нужно вернуть градиенты для всех моментов времени\n",
    "            adj_t = torch.zeros(time_len, bs, 1).to(dLdz)\n",
    "\n",
    "            for i_t in range(time_len-1, 0, -1):\n",
    "                z_i = z[i_t]\n",
    "                t_i = t[i_t]\n",
    "                f_i = func(z_i, t_i).view(bs, n_dim)\n",
    "\n",
    "                # Рассчитаем прямые градиенты от наблюдений\n",
    "                dLdz_i = dLdz[i_t]\n",
    "                dLdt_i = torch.bmm(torch.transpose(dLdz_i.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]\n",
    "\n",
    "                # Подправим ими сопряженные состояния\n",
    "                adj_z += dLdz_i\n",
    "                adj_t[i_t] = adj_t[i_t] - dLdt_i\n",
    "\n",
    "                # Упакуем аугментированные переменные в вектор\n",
    "                aug_z = torch.cat((z_i.view(bs, n_dim), adj_z, torch.zeros(bs, n_params).to(z), adj_t[i_t]), dim=-1)\n",
    "\n",
    "                # Решим (эволюционируем) аугментированную систему назад во времени\n",
    "                aug_ans = ode_solve(aug_z, t_i, t[i_t-1], augmented_dynamics)\n",
    "\n",
    "                # Распакуем переменные обратно из решенной системы\n",
    "                adj_z[:] = aug_ans[:, n_dim:2*n_dim]\n",
    "                adj_p[:] += aug_ans[:, 2*n_dim:2*n_dim + n_params]\n",
    "                adj_t[i_t-1] = aug_ans[:, 2*n_dim + n_params:]\n",
    "\n",
    "                del aug_z, aug_ans\n",
    "\n",
    "            ## Подправим сопряженное состояние в нулевой момент времени прямыми градиентами\n",
    "            # Вычислим прямые градиенты\n",
    "            dLdz_0 = dLdz[0]\n",
    "            dLdt_0 = torch.bmm(torch.transpose(dLdz_0.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]\n",
    "\n",
    "            # Подправим\n",
    "            adj_z += dLdz_0\n",
    "            adj_t[0] = adj_t[0] - dLdt_0\n",
    "        return adj_z.view(bs, *z_shape), adj_t, adj_p, None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Теперь для удобства обернем эту функцию в **nn.Module**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NeuralODE(nn.Module):\n",
    "    def __init__(self, func):\n",
    "        super(NeuralODE, self).__init__()\n",
    "        assert isinstance(func, ODEF)\n",
    "        self.func = func\n",
    "\n",
    "    def forward(self, z0, t=Tensor([0., 1.]), return_whole_sequence=False):\n",
    "        t = t.to(z0)\n",
    "        z = ODEAdjoint.apply(z0, t, self.func.flatten_parameters(), self.func)\n",
    "        if return_whole_sequence:\n",
    "            return z\n",
    "        else:\n",
    "            return z[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Применение"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## *Восстановление* реальной функции динамики (проверка подхода) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "В качестве базового теста проверим теперь, правда ли **Neural ODE** могут восстанавливать истинную функцию динамики, используя данные наблюдений.\n",
    "\n",
    "Для этого мы сначала определим функцию динамики ОДУ, эволюционируем на ее основе траектории, а потом попробуем восстановить ее из случайно параметризованной функции динамики."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Для начала проверим простейший случай линейного ОДУ. Функция динамики это просто действие матрицы."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\n",
    "\\frac{dz}{dt} = \\begin{bmatrix}-0.1 & -1.0\\\\1.0 & -0.1\\end{bmatrix} z\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Обучаемая функция параметризована случаной матрицей."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![learning gif](assets/linear_learning.gif)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Далее чуть более изощренная динамика (без гифки, потому что процесс обучения не такой красивый :))  \n",
    "Обучаемая функция здесь это полносвязная сеть с одним скрытам слоем.\n",
    "![complicated result](assets/comp_result.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Далее код этих примеров (спойлер)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearODEF(ODEF):\n",
    "    def __init__(self, W):\n",
    "        super(LinearODEF, self).__init__()\n",
    "        self.lin = nn.Linear(2, 2, bias=False)\n",
    "        self.lin.weight = nn.Parameter(W)\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        return self.lin(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Функция динамики это просто матрица"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SpiralFunctionExample(LinearODEF):\n",
    "    def __init__(self):\n",
    "        super(SpiralFunctionExample, self).__init__(Tensor([[-0.1, -1.], [1., -0.1]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Случайно параметризованная матрица"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RandomLinearODEF(LinearODEF):\n",
    "    def __init__(self):\n",
    "        super(RandomLinearODEF, self).__init__(torch.randn(2, 2)/2.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Динамика для более изощренных траекторий"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TestODEF(ODEF):\n",
    "    def __init__(self, A, B, x0):\n",
    "        super(TestODEF, self).__init__()\n",
    "        self.A = nn.Linear(2, 2, bias=False)\n",
    "        self.A.weight = nn.Parameter(A)\n",
    "        self.B = nn.Linear(2, 2, bias=False)\n",
    "        self.B.weight = nn.Parameter(B)\n",
    "        self.x0 = nn.Parameter(x0)\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        xTx0 = torch.sum(x*self.x0, dim=1)\n",
    "        dxdt = torch.sigmoid(xTx0) * self.A(x - self.x0) + torch.sigmoid(-xTx0) * self.B(x + self.x0)\n",
    "        return dxdt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Обучаемая динамика в виде полносвязной сети"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NNODEF(ODEF):\n",
    "    def __init__(self, in_dim, hid_dim, time_invariant=False):\n",
    "        super(NNODEF, self).__init__()\n",
    "        self.time_invariant = time_invariant\n",
    "\n",
    "        if time_invariant:\n",
    "            self.lin1 = nn.Linear(in_dim, hid_dim)\n",
    "        else:\n",
    "            self.lin1 = nn.Linear(in_dim+1, hid_dim)\n",
    "        self.lin2 = nn.Linear(hid_dim, hid_dim)\n",
    "        self.lin3 = nn.Linear(hid_dim, in_dim)\n",
    "        self.elu = nn.ELU(inplace=True)\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        if not self.time_invariant:\n",
    "            x = torch.cat((x, t), dim=-1)\n",
    "\n",
    "        h = self.elu(self.lin1(x))\n",
    "        h = self.elu(self.lin2(h))\n",
    "        out = self.lin3(h)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_np(x):\n",
    "    return x.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_trajectories(obs=None, times=None, trajs=None, save=None, figsize=(16, 8)):\n",
    "    plt.figure(figsize=figsize)\n",
    "    if obs is not None:\n",
    "        if times is None:\n",
    "            times = [None] * len(obs)\n",
    "        for o, t in zip(obs, times):\n",
    "            o, t = to_np(o), to_np(t)\n",
    "            for b_i in range(o.shape[1]):\n",
    "                plt.scatter(o[:, b_i, 0], o[:, b_i, 1], c=t[:, b_i, 0], cmap=cm.plasma)\n",
    "\n",
    "    if trajs is not None: \n",
    "        for z in trajs:\n",
    "            z = to_np(z)\n",
    "            plt.plot(z[:, 0, 0], z[:, 0, 1], lw=1.5)\n",
    "        if save is not None:\n",
    "            plt.savefig(save)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conduct_experiment(ode_true, ode_trained, n_steps, name, plot_freq=10):\n",
    "    # Create data\n",
    "    z0 = Variable(torch.Tensor([[0.6, 0.3]]))\n",
    "\n",
    "    t_max = 6.29*5\n",
    "    n_points = 200\n",
    "\n",
    "    index_np = np.arange(0, n_points, 1, dtype=np.int)\n",
    "    index_np = np.hstack([index_np[:, None]])\n",
    "    times_np = np.linspace(0, t_max, num=n_points)\n",
    "    times_np = np.hstack([times_np[:, None]])\n",
    "\n",
    "    times = torch.from_numpy(times_np[:, :, None]).to(z0)\n",
    "    obs = ode_true(z0, times, return_whole_sequence=True).detach()\n",
    "    obs = obs + torch.randn_like(obs) * 0.01\n",
    "\n",
    "    # Get trajectory of random timespan \n",
    "    min_delta_time = 1.0\n",
    "    max_delta_time = 5.0\n",
    "    max_points_num = 32\n",
    "    def create_batch():\n",
    "        t0 = np.random.uniform(0, t_max - max_delta_time)\n",
    "        t1 = t0 + np.random.uniform(min_delta_time, max_delta_time)\n",
    "\n",
    "        idx = sorted(np.random.permutation(index_np[(times_np > t0) & (times_np < t1)])[:max_points_num])\n",
    "\n",
    "        obs_ = obs[idx]\n",
    "        ts_ = times[idx]\n",
    "        return obs_, ts_\n",
    "\n",
    "    # Train Neural ODE\n",
    "    optimizer = torch.optim.Adam(ode_trained.parameters(), lr=0.01)\n",
    "    for i in range(n_steps):\n",
    "        obs_, ts_ = create_batch()\n",
    "\n",
    "        z_ = ode_trained(obs_[0], ts_, return_whole_sequence=True)\n",
    "        loss = F.mse_loss(z_, obs_.detach())\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward(retain_graph=True)\n",
    "        optimizer.step()\n",
    "\n",
    "        if i % plot_freq == 0:\n",
    "            z_p = ode_trained(z0, times, return_whole_sequence=True)\n",
    "\n",
    "            plot_trajectories(obs=[obs], times=[times], trajs=[z_p], save=f\"assets/imgs/{name}/{i}.png\")\n",
    "            clear_output(wait=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "ode_true = NeuralODE(SpiralFunctionExample())\n",
    "ode_trained = NeuralODE(RandomLinearODEF())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-15-91ce782d97e7>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mconduct_experiment\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mode_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mode_trained\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m500\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"linear\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-13-4efbc846ff63>\u001b[0m in \u001b[0;36mconduct_experiment\u001b[0;34m(ode_true, ode_trained, n_steps, name, plot_freq)\u001b[0m\n\u001b[1;32m     38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     39\u001b[0m         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 40\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mretain_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     41\u001b[0m         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     42\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    100\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    101\u001b[0m         \"\"\"\n\u001b[0;32m--> 102\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    103\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    104\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m     88\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m     89\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m     91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "conduct_experiment(ode_true, ode_trained, 500, \"linear\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "func = TestODEF(Tensor([[-0.1, -0.5], [0.5, -0.1]]), Tensor([[0.2, 1.], [-1, 0.2]]), Tensor([[-1., 0.]]))\n",
    "ode_true = NeuralODE(func)\n",
    "\n",
    "func = NNODEF(2, 16, time_invariant=True)\n",
    "ode_trained = NeuralODE(func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conduct_experiment(ode_true, ode_trained, 3000, \"comp\", plot_freq=30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Как можно видеть, *Neural ODE* довольно хорошо справляются с восстановлением динамики. То есть концепция в целом работает.  \n",
    "Теперь проверим на чуть более сложной задаче (MNIST, ха-ха)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Neural ODE вдохновленные ResNets "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "В ResNet'ax скрытое состояние меняется по формуле\n",
    "$$\n",
    "h_{t+1} = h_{t} + f(h_{t}, \\theta_{t})\n",
    "$$\n",
    "\n",
    "где $t \\in \\{0...T\\}$ - это номер блока и $f$ это функция выучиваемая слоями внутри блока.\n",
    "\n",
    "В пределе, если брать бесконечное число блоков со все меньшими шагами, мы получаем непрерывную динамику скрытого слоя в виде ОДУ, прямо как то, что было выше.\n",
    "\n",
    "$$\n",
    "\\frac{dh(t)}{dt} = f(h(t), t, \\theta)\n",
    "$$\n",
    "\n",
    "Начиная со входного слоя $h(0)$, мы можем определить выходной слой $h(T)$ как решение этого ОДУ в момент времени T.\n",
    "\n",
    "Теперь мы можем считать $\\theta$ как распределенные (*shared*) параметры между всеми бесконечно малыми блоками."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Проверка Neural ODE архитектуры на MNIST\n",
    "\n",
    "В этой части мы проверим возможность *Neural ODE* быть использованными в виде компонентов в более привычных архитектурах.\n",
    "\n",
    "В частности, мы заменим остаточные (*residual*) блоки на *Neural ODE* в классификаторе MNIST.\n",
    "\n",
    "<img src=\"assets/mnist_example.png\" width=400></img>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Код ниже (спойлер)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def norm(dim):\n",
    "    return nn.BatchNorm2d(dim)\n",
    "\n",
    "def conv3x3(in_feats, out_feats, stride=1):\n",
    "    return nn.Conv2d(in_feats, out_feats, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "\n",
    "def add_time(in_tensor, t):\n",
    "    bs, c, w, h = in_tensor.shape\n",
    "    return torch.cat((in_tensor, t.expand(bs, 1, w, h)), dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvODEF(ODEF):\n",
    "    def __init__(self, dim):\n",
    "        super(ConvODEF, self).__init__()\n",
    "        self.conv1 = conv3x3(dim + 1, dim)\n",
    "        self.norm1 = norm(dim)\n",
    "        self.conv2 = conv3x3(dim + 1, dim)\n",
    "        self.norm2 = norm(dim)\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        xt = add_time(x, t)\n",
    "        h = self.norm1(torch.relu(self.conv1(xt)))\n",
    "        ht = add_time(h, t)\n",
    "        dxdt = self.norm2(torch.relu(self.conv2(ht)))\n",
    "        return dxdt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ContinuousNeuralMNISTClassifier(nn.Module):\n",
    "    def __init__(self, ode):\n",
    "        super(ContinuousNeuralMNISTClassifier, self).__init__()\n",
    "        self.downsampling = nn.Sequential(\n",
    "            nn.Conv2d(1, 64, 3, 1),\n",
    "            norm(64),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(64, 64, 4, 2, 1),\n",
    "            norm(64),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(64, 64, 4, 2, 1),\n",
    "        )\n",
    "        self.feature = ode\n",
    "        self.norm = norm(64)\n",
    "        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "        self.fc = nn.Linear(64, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.downsampling(x)\n",
    "        x = self.feature(x)\n",
    "        x = self.norm(x)\n",
    "        x = self.avg_pool(x)\n",
    "        shape = torch.prod(torch.tensor(x.shape[1:])).item()\n",
    "        x = x.view(-1, shape)\n",
    "        out = self.fc(x)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "func = ConvODEF(64)\n",
    "ode = NeuralODE(func)\n",
    "model = ContinuousNeuralMNISTClassifier(ode)\n",
    "if use_cuda:\n",
    "    model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "\n",
    "img_std = 0.3081\n",
    "img_mean = 0.1307\n",
    "\n",
    "\n",
    "batch_size = 32\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    torchvision.datasets.MNIST(\"data/mnist\", train=True, download=True,\n",
    "                             transform=torchvision.transforms.Compose([\n",
    "                                 torchvision.transforms.ToTensor(),\n",
    "                                 torchvision.transforms.Normalize((img_mean,), (img_std,))\n",
    "                             ])\n",
    "    ),\n",
    "    batch_size=batch_size, shuffle=True\n",
    ")\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    torchvision.datasets.MNIST(\"data/mnist\", train=False, download=True,\n",
    "                             transform=torchvision.transforms.Compose([\n",
    "                                 torchvision.transforms.ToTensor(),\n",
    "                                 torchvision.transforms.Normalize((img_mean,), (img_std,))\n",
    "                             ])\n",
    "    ),\n",
    "    batch_size=128, shuffle=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    num_items = 0\n",
    "    train_losses = []\n",
    "\n",
    "    model.train()\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    print(f\"Training Epoch {epoch}...\")\n",
    "    for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):\n",
    "        if use_cuda:\n",
    "            data = data.cuda()\n",
    "            target = target.cuda()\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = criterion(output, target) \n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_losses += [loss.item()]\n",
    "        num_items += data.shape[0]\n",
    "    print('Train loss: {:.5f}'.format(np.mean(train_losses)))\n",
    "    return train_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    accuracy = 0.0\n",
    "    num_items = 0\n",
    "\n",
    "    model.eval()\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    print(f\"Testing...\")\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (data, target) in tqdm(enumerate(test_loader),  total=len(test_loader)):\n",
    "            if use_cuda:\n",
    "                data = data.cuda()\n",
    "                target = target.cuda()\n",
    "            output = model(data)\n",
    "            accuracy += torch.sum(torch.argmax(output, dim=1) == target).item()\n",
    "            num_items += data.shape[0]\n",
    "    accuracy = accuracy * 100 / num_items\n",
    "    print(\"Test Accuracy: {:.3f}%\".format(accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "n_epochs = 5\n",
    "test()\n",
    "train_losses = []\n",
    "for epoch in range(1, n_epochs + 1):\n",
    "    train_losses += train(epoch)\n",
    "    test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "plt.figure(figsize=(9, 5))\n",
    "history = pd.DataFrame({\"loss\": train_losses})\n",
    "history[\"cum_data\"] = history.index * batch_size\n",
    "history[\"smooth_loss\"] = history.loss.ewm(halflife=10).mean()\n",
    "history.plot(x=\"cum_data\", y=\"smooth_loss\", figsize=(12, 5), title=\"train error\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "Testing...\n",
    "100% 79/79 [00:01<00:00, 45.69it/s]\n",
    "Test Accuracy: 9.740%\n",
    "\n",
    "Training Epoch 1...\n",
    "100% 1875/1875 [01:15<00:00, 24.69it/s]\n",
    "Train loss: 0.20137\n",
    "Testing...\n",
    "100% 79/79 [00:01<00:00, 46.64it/s]\n",
    "Test Accuracy: 98.680%\n",
    "\n",
    "Training Epoch 2...\n",
    "100% 1875/1875 [01:17<00:00, 24.32it/s]\n",
    "Train loss: 0.05059\n",
    "Testing...\n",
    "100% 79/79 [00:01<00:00, 46.11it/s]\n",
    "Test Accuracy: 97.760%\n",
    "\n",
    "Training Epoch 3...\n",
    "100% 1875/1875 [01:16<00:00, 24.63it/s]\n",
    "Train loss: 0.03808\n",
    "Testing...\n",
    "100% 79/79 [00:01<00:00, 45.65it/s]\n",
    "Test Accuracy: 99.000%\n",
    "\n",
    "Training Epoch 4...\n",
    "100% 1875/1875 [01:17<00:00, 24.28it/s]\n",
    "Train loss: 0.02894\n",
    "Testing...\n",
    "100% 79/79 [00:01<00:00, 45.42it/s]\n",
    "Test Accuracy: 99.130%\n",
    "\n",
    "Training Epoch 5...\n",
    "100% 1875/1875 [01:16<00:00, 24.67it/s]\n",
    "Train loss: 0.02424\n",
    "Testing...\n",
    "100% 79/79 [00:01<00:00, 45.89it/s]\n",
    "Test Accuracy: 99.170%\n",
    "```\n",
    "\n",
    "![train error](assets/train_error.png)\n",
    "\n",
    "После очень грубой тренировки в течение всего 5 эпох и 6 минут обучения, модель уже достигла тестовой ошибки в менее, чем 1%. Можно сказать, что *Нейронные ОДУ* хорошо интегрируются *в виде* компонента в более традиционные сети."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "В своей статье авторы также сравнивают этот классификатор (ODE-Net) с обычной полнозвязной сетью, с ResNet'ом с похожей архитектурой, и с точно такой же архитектурой, только в которой градиент распространяется напрямую через операции в *ODESolve* (без метода сопряженного градиента) (RK-Net).\n",
    "\n",
    "![\"Methods comparison\"](assets/methods_compare.png)\n",
    "<div align=\"center\">Иллюстрация из оригинальной статьи</div>\n",
    "\n",
    "Согласно им, 1-слойная полносвязноая сеть с примерно таким же количеством параметров как *Neural ODE* имеет намного более высокую ошибку на тесте, ResNet с примерно такой же ошибкой имеет намного больше параметров, а RK-Net без метода сопряженного градиента, имеет чуть более высокую ошибку и с линейно растущим потреблением памяти (чем меньше допустимая ошибка, тем больше шагов должен сделать *ODESolve*, что линейно увеличивает потребляемую память с числом шагов)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Авторы в своей имплементации используют неявный метод Рунге-Кутты с адаптивным размером шага, в отличие от простейшего метода Эйлера здесь. Они также изучают некоторые свойства новой архитектуры.\n",
    "\n",
    "![\"Node attrs\"](assets/ode_solver_attrs.png)\n",
    "\n",
    "<div align=\"center\">Характеристика ODE-Net (NFE Forward - количество вычислений функции при прямом проходе)</div>\n",
    "<div align=\"center\">Иллюстрация из оригинальной статьи</div>\n",
    "\n",
    "- (a) Изменение допустимого уровня численной ошибки изменяет количество шагов в прямом распространении.\n",
    "- (b) Время потраченное на прямоу распространение пропорционально количеству вычеслений функции.\n",
    "- (c) Количество вычислений функции при обратном распространение составляет примерно половину от прямого распространения, это указывает на то, что метод сопряженного градиента может быть более вычислительно эффективным, чем распространение градиента напрямую через *ODESolve*.\n",
    "- (d) Как ODE-Net становится все более и более обученным, он требует все больше вычислений функции (все меньший шаг), возможно адаптируясь под восрастающую сложность модели."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Скрытая генеративная функция для моделирования временного ряда "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Neural ODE подходит для обработки непрерывных последовательных данных и тогда, когда траектория лежит в неизвестном скрытом пространстве.\n",
    "\n",
    "В этом разделе мы поэкспер*и*ментируем в генерации непрерывных последовательностей, используя *Neural ODE*, и немножко посмотрим на выученное скрытое пространство.\n",
    "\n",
    "Авторы также сравнивают это с аналогичными последовательност*ями*, сгенерированными Рекуррентными сетями.\n",
    "\n",
    "Эксперимент здесь слегка отличается от соответствующего примера в репозитории авторов, здесь более разнообразное множество траекторий."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Данные"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Обучающие данные состоят из случайных спиралей, половина из которых направлены по-часовой, а вторая - против часовой. Далее случайные подпоследовательности сэмплируются из этих спиралей, обрабатываются кодирующей рекуррентной моделью в обратном направлении, порождая стартовое скрытое состояние, которое затем эволюционирует, создавая траекторию в скрытом пространстве. Это скрытая траектория затем отображается в пространство данных и сравнивается с сэмплированной подпоследовательностью. Таким образом, модель учится генерировать траектории, похожие на датасет."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image.png](assets/spirals_examples.png)\n",
    "<div align=\"center\">Примеры спиралей из датасета</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### VAE как генеративная модель"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Генеративная модель через процедуру сэмплирования:\n",
    "$$\n",
    "z_{t_0} \\sim \\mathcal{N}(0, I)\n",
    "$$\n",
    "\n",
    "$$\n",
    "z_{t_1}, z_{t_2},...,z_{t_M} = \\text{ODESolve}(z_{t_0}, f, \\theta_f, t_0,...,t_M)\n",
    "$$\n",
    "\n",
    "$$\n",
    "x_{t_i} \\sim p(x \\mid z_{t_i};\\theta_x)\n",
    "$$\n",
    "\n",
    "Которая может быть обучена, используя подход вариационных автокодировщиков.\n",
    "\n",
    "1. Пройтись рекуррентным энкодером через временную последовательность назад во времени, чтобы получить параметры $\\mu_{z_{t_0}}$, $\\sigma_{z_{t_0}}$ вариационного апестериорного распределения, а потом сэмплировать из него:\n",
    "$$\n",
    "z_{t_0} \\sim q \\left( z_{t_0} \\mid x_{t_0},...,x_{t_M}; t_0,...,t_M; \\theta_q \\right) = \\mathcal{N} \\left(z_{t_0} \\mid \\mu_{z_{t_0}} \\sigma_{z_{t_0}} \\right)\n",
    "$$\n",
    "2. Получить скрытую траекторию:\n",
    "$$\n",
    "z_{t_1}, z_{t_2},...,z_{t_N} = \\text{ODESolve}(z_{t_0}, f, \\theta_f, t_0,...,t_N), \\text{ где } \\frac{d z}{d t} = f(z, t; \\theta_f)\n",
    "$$\n",
    "3. Отобразить скрытую траекторию в траекторию в данных, используя другую нейросеть: $\\hat{x_{t_i}}(z_{t_i}, t_i; \\theta_x)$\n",
    "4. Максимизировать оценку нижней границы обоснованности (ELBO) для сэмплированной траектории:\n",
    "$$\n",
    "\\text{ELBO} \\approx N \\Big( \\sum_{i=0}^{M} \\log p(x_{t_i} \\mid z_{t_i}(z_{t_0}; \\theta_f); \\theta_x) + KL \\left( q( z_{t_0} \\mid x_{t_0},...,x_{t_M}; t_0,...,t_M; \\theta_q) \\parallel \\mathcal{N}(0, I) \\right) \\Big)\n",
    "$$\n",
    "И в случае Гауссовского апостериорного распределения $p(x \\mid z_{t_i};\\theta_x)$ и известного уровня шума $\\sigma_x$:\n",
    "$$\n",
    "\\text{ELBO} \\approx -N \\Big( \\sum_{i=1}^{M}\\frac{(x_i - \\hat{x_i} )^2}{\\sigma_x^2} - \\log \\sigma_{z_{t_0}}^2 + \\mu_{z_{t_0}}^2 + \\sigma_{z_{t_0}}^2 \\Big) + C\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Граф вычислений скрытой ОДУ модель (*модели?*) можно изобразить вот так\n",
    "![vae_model](assets/vae_model.png)\n",
    "<div align=\"center\">Иллюстрация из оригинальной статьи</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Эту модель можно затем протестировать на то, как она интерполирует траекторию, используя только начальные наблюдения."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Код ниже (спойлер)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNEncoder(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, latent_dim):\n",
    "        super(RNNEncoder, self).__init__()\n",
    "        self.input_dim = input_dim\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "        self.rnn = nn.GRU(input_dim+1, hidden_dim)\n",
    "        self.hid2lat = nn.Linear(hidden_dim, 2*latent_dim)\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        # Concatenate time to input\n",
    "        t = t.clone()\n",
    "        t[1:] = t[:-1] - t[1:]\n",
    "        t[0] = 0.\n",
    "        xt = torch.cat((x, t), dim=-1)\n",
    "\n",
    "        _, h0 = self.rnn(xt.flip((0,)))  # Reversed\n",
    "        # Compute latent dimension\n",
    "        z0 = self.hid2lat(h0[0])\n",
    "        z0_mean = z0[:, :self.latent_dim]\n",
    "        z0_log_var = z0[:, self.latent_dim:]\n",
    "        return z0_mean, z0_log_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NeuralODEDecoder(nn.Module):\n",
    "    def __init__(self, output_dim, hidden_dim, latent_dim):\n",
    "        super(NeuralODEDecoder, self).__init__()\n",
    "        self.output_dim = output_dim\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "        func = NNODEF(latent_dim, hidden_dim, time_invariant=True)\n",
    "        self.ode = NeuralODE(func)\n",
    "        self.l2h = nn.Linear(latent_dim, hidden_dim)\n",
    "        self.h2o = nn.Linear(hidden_dim, output_dim)\n",
    "\n",
    "    def forward(self, z0, t):\n",
    "        zs = self.ode(z0, t, return_whole_sequence=True)\n",
    "\n",
    "        hs = self.l2h(zs)\n",
    "        xs = self.h2o(hs)\n",
    "        return xs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ODEVAE(nn.Module):\n",
    "    def __init__(self, output_dim, hidden_dim, latent_dim):\n",
    "        super(ODEVAE, self).__init__()\n",
    "        self.output_dim = output_dim\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "        self.encoder = RNNEncoder(output_dim, hidden_dim, latent_dim)\n",
    "        self.decoder = NeuralODEDecoder(output_dim, hidden_dim, latent_dim)\n",
    "\n",
    "    def forward(self, x, t, MAP=False):\n",
    "        z_mean, z_log_var = self.encoder(x, t)\n",
    "        if MAP:\n",
    "            z = z_mean\n",
    "        else:\n",
    "            z = z_mean + torch.randn_like(z_mean) * torch.exp(0.5 * z_log_var)\n",
    "        x_p = self.decoder(z, t)\n",
    "        return x_p, z, z_mean, z_log_var\n",
    "\n",
    "    def generate_with_seed(self, seed_x, t):\n",
    "        seed_t_len = seed_x.shape[0]\n",
    "        z_mean, z_log_var = self.encoder(seed_x, t[:seed_t_len])\n",
    "        x_p = self.decoder(z_mean, t)\n",
    "        return x_p"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Генерация датасета"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t_max = 6.29*5\n",
    "n_points = 200\n",
    "noise_std = 0.02\n",
    "\n",
    "num_spirals = 1000\n",
    "\n",
    "index_np = np.arange(0, n_points, 1, dtype=np.int)\n",
    "index_np = np.hstack([index_np[:, None]])\n",
    "times_np = np.linspace(0, t_max, num=n_points)\n",
    "times_np = np.hstack([times_np[:, None]] * num_spirals)\n",
    "times = torch.from_numpy(times_np[:, :, None]).to(torch.float32)\n",
    "\n",
    "# Generate random spirals parameters\n",
    "normal01 = torch.distributions.Normal(0, 1.0)\n",
    "\n",
    "x0 = Variable(normal01.sample((num_spirals, 2))) * 2.0  \n",
    "\n",
    "W11 = -0.1 * normal01.sample((num_spirals,)).abs() - 0.05\n",
    "W22 = -0.1 * normal01.sample((num_spirals,)).abs() - 0.05\n",
    "W21 = -1.0 * normal01.sample((num_spirals,)).abs()\n",
    "W12 =  1.0 * normal01.sample((num_spirals,)).abs()\n",
    "\n",
    "xs_list = []\n",
    "for i in range(num_spirals):\n",
    "    if i % 2 == 1: #  Make it counter-clockwise\n",
    "        W21, W12 = W12, W21\n",
    "\n",
    "    func = LinearODEF(Tensor([[W11[i], W12[i]], [W21[i], W22[i]]]))\n",
    "    ode = NeuralODE(func)\n",
    "\n",
    "    xs = ode(x0[i:i+1], times[:, i:i+1], return_whole_sequence=True)\n",
    "    xs_list.append(xs)\n",
    "\n",
    "\n",
    "orig_trajs = torch.cat(xs_list, dim=1).detach()\n",
    "samp_trajs = orig_trajs + torch.randn_like(orig_trajs) * noise_std\n",
    "samp_ts = times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(15, 9))\n",
    "axes = axes.flatten()\n",
    "for i, ax in enumerate(axes):\n",
    "    ax.scatter(samp_trajs[:, i, 0], samp_trajs[:, i, 1], c=samp_ts[:, i, 0], cmap=cm.plasma)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy.random as npr\n",
    "\n",
    "def gen_batch(batch_size, n_sample=100):\n",
    "    n_batches = samp_trajs.shape[1] // batch_size\n",
    "    time_len = samp_trajs.shape[0]\n",
    "    n_sample = min(n_sample, time_len)\n",
    "    for i in range(n_batches):\n",
    "        if n_sample > 0:\n",
    "            t0_idx = npr.multinomial(1, [1. / (time_len - n_sample)] * (time_len - n_sample))\n",
    "            t0_idx = np.argmax(t0_idx)\n",
    "            tM_idx = t0_idx + n_sample\n",
    "        else:\n",
    "            t0_idx = 0\n",
    "            tM_idx = time_len\n",
    "\n",
    "        frm, to = batch_size*i, batch_size*(i+1)\n",
    "        yield samp_trajs[t0_idx:tM_idx, frm:to], samp_ts[t0_idx:tM_idx, frm:to]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Обучение"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae = ODEVAE(2, 64, 6)\n",
    "vae = vae.cuda()\n",
    "if use_cuda:\n",
    "    vae = vae.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optim = torch.optim.Adam(vae.parameters(), betas=(0.9, 0.999), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "preload = False\n",
    "n_epochs = 20000\n",
    "batch_size = 100\n",
    "\n",
    "plot_traj_idx = 1\n",
    "plot_traj = orig_trajs[:, plot_traj_idx:plot_traj_idx+1]\n",
    "plot_obs = samp_trajs[:, plot_traj_idx:plot_traj_idx+1]\n",
    "plot_ts = samp_ts[:, plot_traj_idx:plot_traj_idx+1]\n",
    "if use_cuda:\n",
    "    plot_traj = plot_traj.cuda()\n",
    "    plot_obs = plot_obs.cuda()\n",
    "    plot_ts = plot_ts.cuda()\n",
    "\n",
    "if preload:\n",
    "    vae.load_state_dict(torch.load(\"models/vae_spirals.sd\"))\n",
    "\n",
    "for epoch_idx in range(n_epochs):\n",
    "    losses = []\n",
    "    train_iter = gen_batch(batch_size)\n",
    "    for x, t in train_iter:\n",
    "        optim.zero_grad()\n",
    "        if use_cuda:\n",
    "            x, t = x.cuda(), t.cuda()\n",
    "\n",
    "        max_len = np.random.choice([30, 50, 100])\n",
    "        permutation = np.random.permutation(t.shape[0])\n",
    "        np.random.shuffle(permutation)\n",
    "        permutation = np.sort(permutation[:max_len])\n",
    "\n",
    "        x, t = x[permutation], t[permutation]\n",
    "\n",
    "        x_p, z, z_mean, z_log_var = vae(x, t)\n",
    "        kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean**2 - torch.exp(z_log_var), -1)\n",
    "        loss = 0.5 * ((x-x_p)**2).sum(-1).sum(0) / noise_std**2 + kl_loss\n",
    "        loss = torch.mean(loss)\n",
    "        loss /= max_len\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        losses.append(loss.item())\n",
    "\n",
    "    print(f\"Epoch {epoch_idx}\")\n",
    "\n",
    "    frm, to, to_seed = 0, 200, 50\n",
    "    seed_trajs = samp_trajs[frm:to_seed]\n",
    "    ts = samp_ts[frm:to]\n",
    "    if use_cuda:\n",
    "        seed_trajs = seed_trajs.cuda()\n",
    "        ts = ts.cuda()\n",
    "\n",
    "    samp_trajs_p = to_np(vae.generate_with_seed(seed_trajs, ts))\n",
    "\n",
    "    fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(15, 9))\n",
    "    axes = axes.flatten()\n",
    "    for i, ax in enumerate(axes):\n",
    "        ax.scatter(to_np(seed_trajs[:, i, 0]), to_np(seed_trajs[:, i, 1]), c=to_np(ts[frm:to_seed, i, 0]), cmap=cm.plasma)\n",
    "        ax.plot(to_np(orig_trajs[frm:to, i, 0]), to_np(orig_trajs[frm:to, i, 1]))\n",
    "        ax.plot(samp_trajs_p[:, i, 0], samp_trajs_p[:, i, 1])\n",
    "    plt.show()\n",
    "\n",
    "    print(np.mean(losses), np.median(losses))\n",
    "    clear_output(wait=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spiral_0_idx = 3\n",
    "spiral_1_idx = 6\n",
    "\n",
    "homotopy_p = Tensor(np.linspace(0., 1., 10)[:, None])\n",
    "vae = vae\n",
    "if use_cuda:\n",
    "    homotopy_p = homotopy_p.cuda()\n",
    "    vae = vae.cuda()\n",
    "\n",
    "spiral_0 = orig_trajs[:, spiral_0_idx:spiral_0_idx+1, :]\n",
    "spiral_1 = orig_trajs[:, spiral_1_idx:spiral_1_idx+1, :]\n",
    "ts_0 = samp_ts[:, spiral_0_idx:spiral_0_idx+1, :]\n",
    "ts_1 = samp_ts[:, spiral_1_idx:spiral_1_idx+1, :]\n",
    "if use_cuda:\n",
    "    spiral_0, ts_0 = spiral_0.cuda(), ts_0.cuda()\n",
    "    spiral_1, ts_1 = spiral_1.cuda(), ts_1.cuda()\n",
    "\n",
    "z_cw, _ = vae.encoder(spiral_0, ts_0)\n",
    "z_cc, _ = vae.encoder(spiral_1, ts_1)\n",
    "\n",
    "homotopy_z = z_cw * (1 - homotopy_p) + z_cc * homotopy_p\n",
    "\n",
    "t = torch.from_numpy(np.linspace(0, 6*np.pi, 200))\n",
    "t = t[:, None].expand(200, 10)[:, :, None].cuda()\n",
    "t = t.cuda() if use_cuda else t\n",
    "hom_gen_trajs = vae.decoder(homotopy_z, t)\n",
    "\n",
    "fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(15, 5))\n",
    "axes = axes.flatten()\n",
    "for i, ax in enumerate(axes):\n",
    "    ax.plot(to_np(hom_gen_trajs[:, i, 0]), to_np(hom_gen_trajs[:, i, 1]))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "torch.save(vae.state_dict(), \"models/vae_spirals.sd\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Вот что получается после ночи обучения\n",
    "![spiral reconstruction with seed](assets/spirals_reconstructed.png)\n",
    "<div align=\"center\"> Точки - это зашумленные наблюдения оригинальной траектории (синий), <br /> желтая - это реконструированная и интерполированная траектория, используя точки как входы. <br /> Цвет точки показывает время. </div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Реконструкции некоторых примеров не выглядят слишком хорошими. Может модель недостаточно сложная или недостаточно долго училась. В любом случае реконструкции выглядят очень разумно.\n",
    "\n",
    "Теперь посмотрим что будет, если интерполировать скрытую переменную по-часовой траектории к противо-часовой траектории.\n",
    "![homotopy](assets/spirals_homotopy.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Авторы также сравнивают реконструкции и интерполяции траекторий между *Neural ODE* и простой Рекуррентной сетью. \n",
    "![ode_rnn_comp](assets/ode_rnn_comp.png)\n",
    "<div align=\"center\">Иллюстрация из оригинальной статьи</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Непрерывные Нормализующие Потоки\n",
    "\n",
    "Оригинальная статья также привносит многое в тему Нормализующих Потоков. Нормализующие потоки используются, когда нужно сэмплировать из некоторого сложного распределения, появившегося через замену переменных от некоторого простого распределения (Гауссовского, например), и при этом все еще знать плотность вероятности в точке каждого сэмпла.\n",
    "Авторы показывают, что использование непрерывной замены переменных намного более вычислительно эффективно и интерпретируемо, чем предыдущие методы. \n",
    "\n",
    "*Нормализующие потоки* очень полезны в таких моделях как *Вариационные Автокодировщики*, *Байесовские Нейронные Сети* и других из Баейсовского подхода.\n",
    "\n",
    "Эта тема, впрочем, лежит за пределами *данной* статьи, и тем, кто заинтересовался, следует прочесть оригинальную научную статью.\n",
    "\n",
    "Для затравки:\n",
    "![CNF_NF_comp](assets/CNF_NF_comp.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div align=\"center\">Визуализация трансформации из шума (простого распределения) в данные (сложное распределение) для двух датасетов; <br /> Ось-X показывает трансформацию плотности и сэмплов с течением \"времени\" (для ННП) и \"глубины\" (для НП). <br />Иллюстрация из оригинальной статьи</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Спасибо @bekemax за помощь в правке английской версии текста и за интересные физические комментарии.\n",
    "\n",
    "Это завершает мое небольшое исследование **Neural ODEs**. Спасибо за внимание, надеюсь вам понравилось!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Полезные ссылки\n",
    "\n",
    "   - [Оригинальная статья](https://arxiv.org/abs/1806.07366)\n",
    "   - [Авторский репозиторий](https://github.com/rtqichen/torchdiffeq)\n",
    "   - [Вариационный вывод](https://www.cs.princeton.edu/courses/archive/fall11/cos597C/lectures/variational-inference-i.pdf)\n",
    "   - [Моя статья про VAE (Русский)](https://habr.com/en/post/331552/)\n",
    "   - [Объяснение VAE](https://www.jeremyjordan.me/variational-autoencoders/)\n",
    "   - [Больше про нормализующие потоки](http://akosiorek.github.io/ml/2018/04/03/norm_flows.html)\n",
    "   - [Variational Inference with Normalizing Flows Paper](https://arxiv.org/abs/1505.05770)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1>Neural Ordinary Differential Equations</h1>\n",
    "Значительная доля процессов описывается дифференциальными уравнениями, это могут быть эволюция физической системы во времени, медицинское состояние пациента, фундаментальные характеристики фондового рынка и т.д. Данные о таких процессах последовательны и непрерывны по своей природе, в том смысле, что наблюдения - это просто проявления какого-то непрерывно изменяющегося состояния.\n",
    "\n",
    "Есть также и другой тип последовательных данных, это дискретные данные, например, данные NLP задач. Состояния в таких данных меняется дискретно: от одного символа или слова к другому.\n",
    "\n",
    "Сейчас оба типа таких последовательных данных обычно обрабатываются рекуррентными сетями, несмотря на то, что они отличны по своей природе, и похоже, требуют различных подходов.\n",
    "\n",
    "На последней <em>NIPS-конференции</em> была представлена одна очень интересная статья, которая может помочь решить эту проблему. Авторы предлагают очень интересный подход, который они назвали <strong>Нейронные Обыкновенные Дифференциальные Уравнения (Neural ODE)</strong>.\n",
    "\n",
    "Здесь я постарался воспроизвести и кратко изложить результаты этой статьи, чтобы сделать знакомство с идеей чуть более простым. Мне кажется, что эта новая архитектура вполне может найти место в стандартном инструментарии дата-сайентиста наряду со сверточными и рекуррентными сетями.\n",
    "\n",
    "<h2>Постановка проблемы</h2>\n",
    "Пусть есть процесс, который подчиняется некоторому неизвестному ОДУ и пусть есть несколько (зашумленных) наблюдений вдоль траектории процесса\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Cfrac%7Bdz%7D%7Bdt%7D%20%3D%20f(z(t)%2C%20t)%20%5C%3B%20(1)%0A\" alt=\"\n",
    "\\frac{dz}{dt} = f(z(t), t) \\; (1)\n",
    "\" />\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5C%7B(z_0%2C%20t_0)%2C(z_1%2C%20t_1)%2C...%2C(z_M%2C%20t_M)%5C%7D%20-%20%5Ctext%7B%D0%BD%D0%B0%D0%B1%D0%BB%D1%8E%D0%B4%D0%B5%D0%BD%D0%B8%D1%8F%7D%0A\" alt=\"\n",
    "\\{(z_0, t_0),(z_1, t_1),...,(z_M, t_M)\\} - \\text{наблюдения}\n",
    "\" />\n",
    "Как найти аппроксимацию <img src=\"https://tex.s2cms.ru/svg/%5Cwidehat%7Bf%7D(z%2C%20t%2C%20%5Ctheta)\" alt=\"\\widehat{f}(z, t, \\theta)\" /> функции динамики <img src=\"https://tex.s2cms.ru/svg/f(z%2C%20t)\" alt=\"f(z, t)\" />?\n",
    "\n",
    "Сначала рассмотрим более простую задачу: есть только 2 наблюдения, в начале и в конце траектории, <img src=\"https://tex.s2cms.ru/svg/(z_0%2C%20t_0)%2C%20(z_1%2C%20t_1)\" alt=\"(z_0, t_0), (z_1, t_1)\" />.\n",
    "Эволюция системы запускается из состояния <img src=\"https://tex.s2cms.ru/svg/z_0%2C%20t_0\" alt=\"z_0, t_0\" /> на время <img src=\"https://tex.s2cms.ru/svg/t_1%20-%20t_0\" alt=\"t_1 - t_0\" /> с какой-то параметризованной функцией динамики, используя любой метод эволюции систем ОДУ. После того, как система оказывается в новом состоянии <img src=\"https://tex.s2cms.ru/svg/%5Chat%7Bz_1%7D%2C%20t_1\" alt=\"\\hat{z_1}, t_1\" />, оно сравнивается с состоянием <img src=\"https://tex.s2cms.ru/svg/z_1\" alt=\"z_1\" /> и разница между ними минимизируется варьированием параметров <img src=\"https://tex.s2cms.ru/svg/%5Ctheta\" alt=\"\\theta\" /> функции динамики.\n",
    "\n",
    "Или, более формально, рассмотрим минимизацию функции потерь <img src=\"https://tex.s2cms.ru/svg/L(%5Chat%7Bz_1%7D)\" alt=\"L(\\hat{z_1})\" />:\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0AL(z(t_1))%20%3D%20L%20%5CBig(%20%5Cint_%7Bt_0%7D%5E%7Bt_1%7D%20f(z(t)%2C%20t%2C%20%5Ctheta)dt%20%5CBig)%20%3D%20L%20%5Cbig(%20%5Ctext%7BODESolve%7D(z(t_0)%2C%20f%2C%20t_0%2C%20t_1%2C%20%5Ctheta)%20%5Cbig)%20%20%5C%3B%20(2)%0A\" alt=\"\n",
    "L(z(t_1)) = L \\Big( \\int_{t_0}^{t_1} f(z(t), t, \\theta)dt \\Big) = L \\big( \\text{ODESolve}(z(t_0), f, t_0, t_1, \\theta) \\big)  \\; (2)\n",
    "\" />\n",
    "<img src=\"https://habrastorage.org/webt/7e/65/hf/7e65hfxs1amdqyy_uy6emdwulkg.png\" width=600/>\n",
    "<p style=\"text-align: center\">Картинка 1: Непрерывный *backpropagation* градиента требует решения аугментированного дифференциального уравнения назад во времени. <br /> Стрелки представляют корректировку распространенных назад градиентов градиентами от наблюдений.<br />\n",
    "Иллюстрация из оригинальной статьи</p>\n",
    "Чтобы минимизировать <img src=\"https://tex.s2cms.ru/svg/L\" alt=\"L\" />, нужно рассчитать градиенты по всем его параметрами: <img src=\"https://tex.s2cms.ru/svg/z(t_0)%2C%20t_0%2C%20t_1%2C%20%5Ctheta\" alt=\"z(t_0), t_0, t_1, \\theta\" />. Чтобы сделать это, сначала нужно определить, как <img src=\"https://tex.s2cms.ru/svg/L\" alt=\"L\" /> зависит от состояния в каждый момент времени <img src=\"https://tex.s2cms.ru/svg/(z(t))\" alt=\"(z(t))\" />:\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0Aa(t)%20%3D%20-%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20z(t)%7D%20%5C%3B%20(3)%0A\" alt=\"\n",
    "a(t) = -\\frac{\\partial L}{\\partial z(t)} \\; (3)\n",
    "\" />\n",
    "<img src=\"https://tex.s2cms.ru/svg/a(t)\" alt=\"a(t)\" /> зовется <em>сопряженным</em> (<em>adjoint</em>) состоянием, его динамика задается другим дифференциальными уравнением, которое можно считать непрерывным аналогом дифференцирования сложной функции (<em>chain rule</em>):\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Cfrac%7Bd%20a(t)%7D%7Bd%20t%7D%20%3D%20-a(t)%20%5Cfrac%7B%5Cpartial%20f(z(t)%2C%20t%2C%20%5Ctheta)%7D%7B%5Cpartial%20z%7D%20%5C%3B%20(4)%0A\" alt=\"\n",
    "\\frac{d a(t)}{d t} = -a(t) \\frac{\\partial f(z(t), t, \\theta)}{\\partial z} \\; (4)\n",
    "\" />\n",
    "Вывод этой формулы можно посмотреть в аппендиксе оригинальной статьи.\n",
    "\n",
    "Векторы в этой статье следует считать строчными векторами, хотя оригинальная статья использует и строчное и столбцовое представление.\n",
    "\n",
    "Решая диффур (4) назад во времени, получаем зависимость от начального состояния <img src=\"https://tex.s2cms.ru/svg/z(t_0)\" alt=\"z(t_0)\" />:\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20z(t_0)%7D%20%3D%20%5Cint_%7Bt_1%7D%5E%7Bt_0%7D%20a(t)%20%5Cfrac%7B%5Cpartial%20f(z(t)%2C%20t%2C%20%5Ctheta)%7D%7B%5Cpartial%20z%7D%20dt%20%5C%3B%20(5)%0A\" alt=\"\n",
    "\\frac{\\partial L}{\\partial z(t_0)} = \\int_{t_1}^{t_0} a(t) \\frac{\\partial f(z(t), t, \\theta)}{\\partial z} dt \\; (5)\n",
    "\" />\n",
    "Чтобы рассчитать градиент по отношению к <img src=\"https://tex.s2cms.ru/svg/t\" alt=\"t\" /> and <img src=\"https://tex.s2cms.ru/svg/%5Ctheta\" alt=\"\\theta\" />, можно просто считать их частью состояния. Такое состояние зовется <em>аугментированным</em>. Динамика такого состояния тривиально получается из оригинальной динамики:\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Cfrac%7Bd%7D%7Bdt%7D%20%5Cbegin%7Bbmatrix%7D%20z%20%5C%5C%20%5Ctheta%20%5C%5C%20t%20%5Cend%7Bbmatrix%7D%20(t)%20%3D%20f_%7B%5Ctext%7Baug%7D%7D(%5Bz%2C%20%5Ctheta%2C%20t%5D)%20%3A%3D%20%5Cbegin%7Bbmatrix%7D%20f(%5Bz%2C%20%5Ctheta%2C%20t%20%5D)%20%5C%5C%200%20%5C%5C%201%20%5Cend%7Bbmatrix%7D%20%5C%3B%20(6)%0A\" alt=\"\n",
    "\\frac{d}{dt} \\begin{bmatrix} z \\\\ \\theta \\\\ t \\end{bmatrix} (t) = f_{\\text{aug}}([z, \\theta, t]) := \\begin{bmatrix} f([z, \\theta, t ]) \\\\ 0 \\\\ 1 \\end{bmatrix} \\; (6)\n",
    "\" />\n",
    "Тогда сопряженное состояние к этому аугментированному состоянию:\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0Aa_%7B%5Ctext%7Baug%7D%7D%20%3A%3D%20%5Cbegin%7Bbmatrix%7D%20a%20%5C%5C%20a_%7B%5Ctheta%7D%20%5C%5C%20a_t%20%5Cend%7Bbmatrix%7D%2C%20a_%7B%5Ctheta%7D(t)%20%3A%3D%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Ctheta(t)%7D%2C%20a_t(t)%20%3A%3D%20%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20t(t)%7D%20%5C%3B%20(7)%0A\" alt=\"\n",
    "a_{\\text{aug}} := \\begin{bmatrix} a \\\\ a_{\\theta} \\\\ a_t \\end{bmatrix}, a_{\\theta}(t) := \\frac{\\partial L}{\\partial \\theta(t)}, a_t(t) := \\frac{\\partial L}{\\partial t(t)} \\; (7)\n",
    "\" />\n",
    "Градиент аугментированной динамики:\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Cfrac%7B%5Cpartial%20f_%7B%5Ctext%7Baug%7D%7D%7D%7B%5Cpartial%20%5Bz%2C%20%5Ctheta%2C%20t%5D%7D%20%3D%20%5Cbegin%7Bbmatrix%7D%20%0A%5Cfrac%7B%5Cpartial%20f%7D%7B%5Cpartial%20z%7D%20%26%20%5Cfrac%7B%5Cpartial%20f%7D%7B%5Cpartial%20%5Ctheta%7D%20%26%20%5Cfrac%7B%5Cpartial%20f%7D%7B%5Cpartial%20t%7D%20%5C%5C%0A0%20%26%200%20%26%200%20%5C%5C%0A0%20%26%200%20%26%200%0A%5Cend%7Bbmatrix%7D%20%5C%3B%20(8)%0A\" alt=\"\n",
    "\\frac{\\partial f_{\\text{aug}}}{\\partial [z, \\theta, t]} = \\begin{bmatrix} \n",
    "\\frac{\\partial f}{\\partial z} &; \\frac{\\partial f}{\\partial \\theta} &; \\frac{\\partial f}{\\partial t} \\\\\n",
    "0 &; 0 &; 0 \\\\\n",
    "0 &; 0 &; 0\n",
    "\\end{bmatrix} \\; (8)\n",
    "\" />\n",
    "Дифференциальное уравнение сопряженного аугментированного состояния из формулы (4) тогда:\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Cfrac%7Bd%20a_%7B%5Ctext%7Baug%7D%7D%7D%7Bdt%7D%20%3D%20-%20%5Cbegin%7Bbmatrix%7D%20a%5Cfrac%7B%5Cpartial%20f%7D%7B%5Cpartial%20z%7D%20%26%20a%5Cfrac%7B%5Cpartial%20f%7D%7B%5Cpartial%20%5Ctheta%7D%20%26%20a%5Cfrac%7B%5Cpartial%20f%7D%7B%5Cpartial%20t%7D%5Cend%7Bbmatrix%7D%20%5C%3B%20(9)%0A\" alt=\"\n",
    "\\frac{d a_{\\text{aug}}}{dt} = - \\begin{bmatrix} a\\frac{\\partial f}{\\partial z} &; a\\frac{\\partial f}{\\partial \\theta} &; a\\frac{\\partial f}{\\partial t}\\end{bmatrix} \\; (9)\n",
    "\" />\n",
    "Решение этого ОДУ назад во времени дает:\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20z(t_0)%7D%20%3D%20%5Cint_%7Bt_1%7D%5E%7Bt_0%7D%20a(t)%20%5Cfrac%7B%5Cpartial%20f(z(t)%2C%20t%2C%20%5Ctheta)%7D%7B%5Cpartial%20z%7D%20dt%20%5C%3B%20(10)%0A\" alt=\"\n",
    "\\frac{\\partial L}{\\partial z(t_0)} = \\int_{t_1}^{t_0} a(t) \\frac{\\partial f(z(t), t, \\theta)}{\\partial z} dt \\; (10)\n",
    "\" />\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20%5Ctheta%7D%20%3D%20%5Cint_%7Bt_1%7D%5E%7Bt_0%7D%20a(t)%20%5Cfrac%7B%5Cpartial%20f(z(t)%2C%20t%2C%20%5Ctheta)%7D%7B%5Cpartial%20%5Ctheta%7D%20dt%20%5C%3B%20(11)%0A\" alt=\"\n",
    "\\frac{\\partial L}{\\partial \\theta} = \\int_{t_1}^{t_0} a(t) \\frac{\\partial f(z(t), t, \\theta)}{\\partial \\theta} dt \\; (11)\n",
    "\" />\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20t_0%7D%20%3D%20%5Cint_%7Bt_1%7D%5E%7Bt_0%7D%20a(t)%20%5Cfrac%7B%5Cpartial%20f(z(t)%2C%20t%2C%20%5Ctheta)%7D%7B%5Cpartial%20t%7D%20dt%20%5C%3B%20(12)%0A\" alt=\"\n",
    "\\frac{\\partial L}{\\partial t_0} = \\int_{t_1}^{t_0} a(t) \\frac{\\partial f(z(t), t, \\theta)}{\\partial t} dt \\; (12)\n",
    "\" />\n",
    "Что вместе с\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Cfrac%7B%5Cpartial%20L%7D%7B%5Cpartial%20t_1%7D%20%3D%20-%20a(t)%20%5Cfrac%7B%5Cpartial%20f(z(t)%2C%20t%2C%20%5Ctheta)%7D%7B%5Cpartial%20t%7D%20%5C%3B%20(13)%0A\" alt=\"\n",
    "\\frac{\\partial L}{\\partial t_1} = - a(t) \\frac{\\partial f(z(t), t, \\theta)}{\\partial t} \\; (13)\n",
    "\" />\n",
    "дает градиенты по всем входным параметрам в решатель ОДУ <em>ODESolve</em>.\n",
    "\n",
    "Все градиенты (10), (11), (12), (13) могут быть рассчитаны вместе за один вызов <em>ODESolve</em> с динамикой сопряженного аугментированного состояния (9).\n",
    "<img src=\"https://habrastorage.org/webt/8k/pz/uk/8kpzukmizpmezmywov4b3zm29lc.png\" width=800/>\n",
    "<div align=\"center\">Иллюстрация из оригинальной статьи</div>\n",
    "Алгоритм выше описывает обратное распространения градиента решения ОДУ для последовательных наблюдений. Этот алгоритм реализует в себе все описанное выше.\n",
    "\n",
    "В случае нескольких наблюдений на одну траекторию все рассчитывается так же, но в моменты наблюдений обратно распространенный градиент надо корректировать градиентами от текущего наблюдения, как показано в <em>иллюстрации 1</em>.\n",
    "\n",
    "<h1>Реализация</h1>\n",
    "Код ниже - это моя реализация <strong>Нейронных ОДУ</strong>. Я делал это сугубо для лучшего понимания того, что происходит. Впрочем, оно очень близка к тому, что реализовано в <a href=\"https://github.com/rtqichen/torchdiffeq\">репозитории</a> у авторов статьи. Здесь содержится весь нужный для понимания код в одном месте, он также слегка более закомментированный. Для реального применения и экспериментов все же лучше использовать реализацию авторов оригинальной статьи.\n",
    "\n",
    "<source lang=\"python\">import math\n",
    "import numpy as np\n",
    "from IPython.display import clear_output\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import seaborn as sns\n",
    "sns.color_palette(\"bright\")\n",
    "import matplotlib as mpl\n",
    "import matplotlib.cm as cm\n",
    "\n",
    "import torch\n",
    "from torch import Tensor\n",
    "from torch import nn\n",
    "from torch.nn  import functional as F \n",
    "from torch.autograd import Variable\n",
    "\n",
    "use_cuda = torch.cuda.is_available()\n",
    "</source>\n",
    "Для начала надо реализовать любой метод эволюции систем ОДУ. В целях простоты здесь реализован метод Эйлера, хотя подойдет любой явный или неявный метод.\n",
    "\n",
    "<source lang=\"python\">def ode_solve(z0, t0, t1, f):\n",
    "    \"\"\"\n",
    "    Простейший метод эволюции ОДУ - метод Эйлера\n",
    "    \"\"\"\n",
    "    h_max = 0.05\n",
    "    n_steps = math.ceil((abs(t1 - t0)/h_max).max().item())\n",
    "\n",
    "    h = (t1 - t0)/n_steps\n",
    "    t = t0\n",
    "    z = z0\n",
    "\n",
    "    for i_step in range(n_steps):\n",
    "        z = z + h * f(z, t)\n",
    "        t = t + h\n",
    "    return z\n",
    "</source>\n",
    "Здесь также описан суперкласс параметризованной функции динамики с парочкой полезных методов.\n",
    "\n",
    "Во-первых: нужно возвращать все параметры от которых зависит функция в виде вектора.\n",
    "\n",
    "Во-вторых: надо рассчитывать аугментированную динамику. Эта динамика зависит от градиента параметризованной функции по параметрам и входным данным. Чтобы не приходилось каждый раз для каждой новой архитектуры прописывать градиент руками, воспользуемся методом <strong>torch.autograd.grad</strong>.\n",
    "\n",
    "<source lang=\"python\">class ODEF(nn.Module):\n",
    "    def forward_with_grad(self, z, t, grad_outputs):\n",
    "        \"\"\"Compute f and a df/dz, a df/dp, a df/dt\"\"\"\n",
    "        batch_size = z.shape[0]\n",
    "\n",
    "        out = self.forward(z, t)\n",
    "\n",
    "        a = grad_outputs\n",
    "        adfdz, adfdt, *adfdp = torch.autograd.grad(\n",
    "            (out,), (z, t) + tuple(self.parameters()), grad_outputs=(a),\n",
    "            allow_unused=True, retain_graph=True\n",
    "        )\n",
    "        # метод grad автоматически суммирует градие*н*ты для всех элементов батча, надо expand их обратно \n",
    "        if adfdp is not None:\n",
    "            adfdp = torch.cat([p_grad.flatten() for p_grad in adfdp]).unsqueeze(0)\n",
    "            adfdp = adfdp.expand(batch_size, -1) / batch_size\n",
    "        if adfdt is not None:\n",
    "            adfdt = adfdt.expand(batch_size, 1) / batch_size\n",
    "        return out, adfdz, adfdt, adfdp\n",
    "\n",
    "    def flatten_parameters(self):\n",
    "        p_shapes = []\n",
    "        flat_parameters = []\n",
    "        for p in self.parameters():\n",
    "            p_shapes.append(p.size())\n",
    "            flat_parameters.append(p.flatten())\n",
    "        return torch.cat(flat_parameters)\n",
    "</source>\n",
    "Код ниже описывает прямое и обратное распространение для <em>Нейронных ОДУ</em>. Приходится отделить этот код от основного <strong><em>torch.nn.Module</em></strong> в виде функции <strong><em>torch.autograd.Function</em></strong> потому, что в последнем можно реализовать произвольный метод обратного распространения, в отличие от модуля. Так что это просто костыль.\n",
    "\n",
    "Эта функция лежит в основе всего подхода <em>Нейронных ОДУ</em>.\n",
    "\n",
    "<source lang=\"python\">class ODEAdjoint(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, z0, t, flat_parameters, func):\n",
    "        assert isinstance(func, ODEF)\n",
    "        bs, *z_shape = z0.size()\n",
    "        time_len = t.size(0)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            z = torch.zeros(time_len, bs, *z_shape).to(z0)\n",
    "            z[0] = z0\n",
    "            for i_t in range(time_len - 1):\n",
    "                z0 = ode_solve(z0, t[i_t], t[i_t+1], func)\n",
    "                z[i_t+1] = z0\n",
    "\n",
    "        ctx.func = func\n",
    "        ctx.save_for_backward(t, z.clone(), flat_parameters)\n",
    "        return z\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, dLdz):\n",
    "        \"\"\"\n",
    "        dLdz shape: time_len, batch_size, *z_shape\n",
    "        \"\"\"\n",
    "        func = ctx.func\n",
    "        t, z, flat_parameters = ctx.saved_tensors\n",
    "        time_len, bs, *z_shape = z.size()\n",
    "        n_dim = np.prod(z_shape)\n",
    "        n_params = flat_parameters.size(0)\n",
    "\n",
    "        # Динамика аугментированной системы, которую надо эволюционировать обратно во времени\n",
    "        def augmented_dynamics(aug_z_i, t_i):\n",
    "            \"\"\"\n",
    "            Тензоры здесь - это срезы по времени\n",
    "            t_i - тензор с размерами: bs, 1\n",
    "            aug_z_i - тензор с размерами: bs, n_dim*2 + n_params + 1\n",
    "            \"\"\"\n",
    "            z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2*n_dim]  # игнорируем параметры и время\n",
    "\n",
    "            # Unflatten z and a\n",
    "            z_i = z_i.view(bs, *z_shape)\n",
    "            a = a.view(bs, *z_shape)\n",
    "            with torch.set_grad_enabled(True):\n",
    "                t_i = t_i.detach().requires_grad_(True)\n",
    "                z_i = z_i.detach().requires_grad_(True)\n",
    "                func_eval, adfdz, adfdt, adfdp = func.forward_with_grad(z_i, t_i, grad_outputs=a)  # bs, *z_shape\n",
    "                adfdz = adfdz.to(z_i) if adfdz is not None else torch.zeros(bs, *z_shape).to(z_i)\n",
    "                adfdp = adfdp.to(z_i) if adfdp is not None else torch.zeros(bs, n_params).to(z_i)\n",
    "                adfdt = adfdt.to(z_i) if adfdt is not None else torch.zeros(bs, 1).to(z_i)\n",
    "\n",
    "            # Flatten f and adfdz\n",
    "            func_eval = func_eval.view(bs, n_dim)\n",
    "            adfdz = adfdz.view(bs, n_dim) \n",
    "            return torch.cat((func_eval, -adfdz, -adfdp, -adfdt), dim=1)\n",
    "\n",
    "        dLdz = dLdz.view(time_len, bs, n_dim)  # flatten dLdz для удобства\n",
    "        with torch.no_grad():\n",
    "            ## Создадим плейсхолдеры для возвращаемых градиентов\n",
    "            # Распространенные назад сопряженные состояния, которые надо поправить градиентами от наблюдений\n",
    "            adj_z = torch.zeros(bs, n_dim).to(dLdz)\n",
    "            adj_p = torch.zeros(bs, n_params).to(dLdz)\n",
    "            # В отличие от z и p, нужно вернуть градиенты для всех моментов времени\n",
    "            adj_t = torch.zeros(time_len, bs, 1).to(dLdz)\n",
    "\n",
    "            for i_t in range(time_len-1, 0, -1):\n",
    "                z_i = z[i_t]\n",
    "                t_i = t[i_t]\n",
    "                f_i = func(z_i, t_i).view(bs, n_dim)\n",
    "\n",
    "                # Рассчитаем прямые градиенты от наблюдений\n",
    "                dLdz_i = dLdz[i_t]\n",
    "                dLdt_i = torch.bmm(torch.transpose(dLdz_i.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]\n",
    "\n",
    "                # Подправим ими сопряженные состояния\n",
    "                adj_z += dLdz_i\n",
    "                adj_t[i_t] = adj_t[i_t] - dLdt_i\n",
    "\n",
    "                # Упакуем аугментированные переменные в вектор\n",
    "                aug_z = torch.cat((z_i.view(bs, n_dim), adj_z, torch.zeros(bs, n_params).to(z), adj_t[i_t]), dim=-1)\n",
    "\n",
    "                # Решим (эволюционируем) аугментированную систему назад во времени\n",
    "                aug_ans = ode_solve(aug_z, t_i, t[i_t-1], augmented_dynamics)\n",
    "\n",
    "                # Распакуем переменные обратно из решенной системы\n",
    "                adj_z[:] = aug_ans[:, n_dim:2*n_dim]\n",
    "                adj_p[:] += aug_ans[:, 2*n_dim:2*n_dim + n_params]\n",
    "                adj_t[i_t-1] = aug_ans[:, 2*n_dim + n_params:]\n",
    "\n",
    "                del aug_z, aug_ans\n",
    "\n",
    "            ## Подправим сопряженное состояние в нулевой момент времени прямыми градиентами\n",
    "            # Вычислим прямые градиенты\n",
    "            dLdz_0 = dLdz[0]\n",
    "            dLdt_0 = torch.bmm(torch.transpose(dLdz_0.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]\n",
    "\n",
    "            # Подправим\n",
    "            adj_z += dLdz_0\n",
    "            adj_t[0] = adj_t[0] - dLdt_0\n",
    "        return adj_z.view(bs, *z_shape), adj_t, adj_p, None\n",
    "</source>\n",
    "Теперь для удобства обернем эту функцию в <strong>nn.Module</strong>.\n",
    "\n",
    "<source lang=\"python\">class NeuralODE(nn.Module):\n",
    "    def __init__(self, func):\n",
    "        super(NeuralODE, self).__init__()\n",
    "        assert isinstance(func, ODEF)\n",
    "        self.func = func\n",
    "\n",
    "    def forward(self, z0, t=Tensor([0., 1.]), return_whole_sequence=False):\n",
    "        t = t.to(z0)\n",
    "        z = ODEAdjoint.apply(z0, t, self.func.flatten_parameters(), self.func)\n",
    "        if return_whole_sequence:\n",
    "            return z\n",
    "        else:\n",
    "            return z[-1]\n",
    "</source>\n",
    "<h1>Применение</h1>\n",
    "<h2><em>Восстановление</em> реальной функции динамики (проверка подхода)</h2>\n",
    "В качестве базового теста проверим теперь, правда ли <strong>Neural ODE</strong> могут восстанавливать истинную функцию динамики, используя данные наблюдений.\n",
    "\n",
    "Для этого мы сначала определим функцию динамики ОДУ, эволюционируем на ее основе траектории, а потом попробуем восстановить ее из случайно параметризованной функции динамики.\n",
    "\n",
    "Для начала проверим простейший случай линейного ОДУ. Функция динамики это просто действие матрицы.\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Cfrac%7Bdz%7D%7Bdt%7D%20%3D%20%5Cbegin%7Bbmatrix%7D-0.1%20%26%20-1.0%5C%5C1.0%20%26%20-0.1%5Cend%7Bbmatrix%7D%20z%0A\" alt=\"\n",
    "\\frac{dz}{dt} = \\begin{bmatrix}-0.1 &; -1.0\\\\1.0 &; -0.1\\end{bmatrix} z\n",
    "\" />\n",
    "Обучаемая функция параметризована случаной матрицей.\n",
    "<img src=\"https://habrastorage.org/webt/nj/q6/es/njq6eswuevh4rhx-8nhzyz-cq_k.gif\" />\n",
    "Далее чуть более изощренная динамика (без гифки, потому что процесс обучения не такой красивый :))<br>\n",
    "Обучаемая функция здесь это полносвязная сеть с одним скрытам слоем.\n",
    "<img src=\"https://habrastorage.org/webt/tr/4n/eb/tr4nebjdrs4pt4v5vlzleai8exe.png\" />\n",
    "\n",
    "Далее код этих примеров (спойлер)\n",
    "\n",
    "<source lang=\"python\">class LinearODEF(ODEF):\n",
    "    def __init__(self, W):\n",
    "        super(LinearODEF, self).__init__()\n",
    "        self.lin = nn.Linear(2, 2, bias=False)\n",
    "        self.lin.weight = nn.Parameter(W)\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        return self.lin(x)\n",
    "</source>\n",
    "Функция динамики это просто матрица\n",
    "\n",
    "<source lang=\"python\">class SpiralFunctionExample(LinearODEF):\n",
    "    def __init__(self):\n",
    "        super(SpiralFunctionExample, self).__init__(Tensor([[-0.1, -1.], [1., -0.1]]))\n",
    "</source>\n",
    "Случайно параметризованная матрица\n",
    "\n",
    "<source lang=\"python\">class RandomLinearODEF(LinearODEF):\n",
    "    def __init__(self):\n",
    "        super(RandomLinearODEF, self).__init__(torch.randn(2, 2)/2.)\n",
    "</source>\n",
    "Динамика для более изощренных траекторий\n",
    "\n",
    "<source lang=\"python\">class TestODEF(ODEF):\n",
    "    def __init__(self, A, B, x0):\n",
    "        super(TestODEF, self).__init__()\n",
    "        self.A = nn.Linear(2, 2, bias=False)\n",
    "        self.A.weight = nn.Parameter(A)\n",
    "        self.B = nn.Linear(2, 2, bias=False)\n",
    "        self.B.weight = nn.Parameter(B)\n",
    "        self.x0 = nn.Parameter(x0)\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        xTx0 = torch.sum(x*self.x0, dim=1)\n",
    "        dxdt = torch.sigmoid(xTx0) * self.A(x - self.x0) + torch.sigmoid(-xTx0) * self.B(x + self.x0)\n",
    "        return dxdt\n",
    "</source>\n",
    "Обучаемая динамика в виде полносвязной сети\n",
    "\n",
    "<source lang=\"python\">class NNODEF(ODEF):\n",
    "    def __init__(self, in_dim, hid_dim, time_invariant=False):\n",
    "        super(NNODEF, self).__init__()\n",
    "        self.time_invariant = time_invariant\n",
    "\n",
    "        if time_invariant:\n",
    "            self.lin1 = nn.Linear(in_dim, hid_dim)\n",
    "        else:\n",
    "            self.lin1 = nn.Linear(in_dim+1, hid_dim)\n",
    "        self.lin2 = nn.Linear(hid_dim, hid_dim)\n",
    "        self.lin3 = nn.Linear(hid_dim, in_dim)\n",
    "        self.elu = nn.ELU(inplace=True)\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        if not self.time_invariant:\n",
    "            x = torch.cat((x, t), dim=-1)\n",
    "\n",
    "        h = self.elu(self.lin1(x))\n",
    "        h = self.elu(self.lin2(h))\n",
    "        out = self.lin3(h)\n",
    "        return out\n",
    "\n",
    "def to_np(x):\n",
    "    return x.detach().cpu().numpy()\n",
    "\n",
    "def plot_trajectories(obs=None, times=None, trajs=None, save=None, figsize=(16, 8)):\n",
    "    plt.figure(figsize=figsize)\n",
    "    if obs is not None:\n",
    "        if times is None:\n",
    "            times = [None] * len(obs)\n",
    "        for o, t in zip(obs, times):\n",
    "            o, t = to_np(o), to_np(t)\n",
    "            for b_i in range(o.shape[1]):\n",
    "                plt.scatter(o[:, b_i, 0], o[:, b_i, 1], c=t[:, b_i, 0], cmap=cm.plasma)\n",
    "\n",
    "    if trajs is not None: \n",
    "        for z in trajs:\n",
    "            z = to_np(z)\n",
    "            plt.plot(z[:, 0, 0], z[:, 0, 1], lw=1.5)\n",
    "        if save is not None:\n",
    "            plt.savefig(save)\n",
    "    plt.show()\n",
    "\n",
    "def conduct_experiment(ode_true, ode_trained, n_steps, name, plot_freq=10):\n",
    "    # Create data\n",
    "    z0 = Variable(torch.Tensor([[0.6, 0.3]]))\n",
    "\n",
    "    t_max = 6.29*5\n",
    "    n_points = 200\n",
    "\n",
    "    index_np = np.arange(0, n_points, 1, dtype=np.int)\n",
    "    index_np = np.hstack([index_np[:, None]])\n",
    "    times_np = np.linspace(0, t_max, num=n_points)\n",
    "    times_np = np.hstack([times_np[:, None]])\n",
    "\n",
    "    times = torch.from_numpy(times_np[:, :, None]).to(z0)\n",
    "    obs = ode_true(z0, times, return_whole_sequence=True).detach()\n",
    "    obs = obs + torch.randn_like(obs) * 0.01\n",
    "\n",
    "    # Get trajectory of random timespan \n",
    "    min_delta_time = 1.0\n",
    "    max_delta_time = 5.0\n",
    "    max_points_num = 32\n",
    "    def create_batch():\n",
    "        t0 = np.random.uniform(0, t_max - max_delta_time)\n",
    "        t1 = t0 + np.random.uniform(min_delta_time, max_delta_time)\n",
    "\n",
    "        idx = sorted(np.random.permutation(index_np[(times_np >; t0) &; (times_np <; t1)])[:max_points_num])\n",
    "\n",
    "        obs_ = obs[idx]\n",
    "        ts_ = times[idx]\n",
    "        return obs_, ts_\n",
    "\n",
    "    # Train Neural ODE\n",
    "    optimizer = torch.optim.Adam(ode_trained.parameters(), lr=0.01)\n",
    "    for i in range(n_steps):\n",
    "        obs_, ts_ = create_batch()\n",
    "\n",
    "        z_ = ode_trained(obs_[0], ts_, return_whole_sequence=True)\n",
    "        loss = F.mse_loss(z_, obs_.detach())\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward(retain_graph=True)\n",
    "        optimizer.step()\n",
    "\n",
    "        if i % plot_freq == 0:\n",
    "            z_p = ode_trained(z0, times, return_whole_sequence=True)\n",
    "\n",
    "            plot_trajectories(obs=[obs], times=[times], trajs=[z_p], save=f\"assets/imgs/{name}/{i}.png\")\n",
    "            clear_output(wait=True)\n",
    "\n",
    "ode_true = NeuralODE(SpiralFunctionExample())\n",
    "ode_trained = NeuralODE(RandomLinearODEF())\n",
    "\n",
    "conduct_experiment(ode_true, ode_trained, 500, \"linear\")\n",
    "\n",
    "func = TestODEF(Tensor([[-0.1, -0.5], [0.5, -0.1]]), Tensor([[0.2, 1.], [-1, 0.2]]), Tensor([[-1., 0.]]))\n",
    "ode_true = NeuralODE(func)\n",
    "\n",
    "func = NNODEF(2, 16, time_invariant=True)\n",
    "ode_trained = NeuralODE(func)\n",
    "\n",
    "conduct_experiment(ode_true, ode_trained, 3000, \"comp\", plot_freq=30)\n",
    "</source>\n",
    "Как можно видеть, <em>Neural ODE</em> довольно хорошо справляются с восстановлением динамики. То есть концепция в целом работает.<br>\n",
    "Теперь проверим на чуть более сложной задаче (MNIST, ха-ха).\n",
    "\n",
    "<h2>Neural ODE вдохновленные ResNets</h2>\n",
    "В ResNet’ax скрытое состояние меняется по формуле\n",
    "<img src=\"https://tex.s2cms.ru/svg/%0Ah_%7Bt%2B1%7D%20%3D%20h_%7Bt%7D%20%2B%20f(h_%7Bt%7D%2C%20%5Ctheta_%7Bt%7D)%0A\" alt=\"\n",
    "h_{t+1} = h_{t} + f(h_{t}, \\theta_{t})\n",
    "\" />\n",
    "\n",
    "где <img src=\"https://tex.s2cms.ru/svg/t%20%5Cin%20%5C%7B0...T%5C%7D\" alt=\"t \\in \\{0...T\\}\" /> - это номер блока и <img src=\"https://tex.s2cms.ru/svg/f\" alt=\"f\" /> это функция выучиваемая слоями внутри блока.\n",
    "\n",
    "В пределе, если брать бесконечное число блоков со все меньшими шагами, мы получаем непрерывную динамику скрытого слоя в виде ОДУ, прямо как то, что было выше.\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Cfrac%7Bdh(t)%7D%7Bdt%7D%20%3D%20f(h(t)%2C%20t%2C%20%5Ctheta)%0A\" alt=\"\n",
    "\\frac{dh(t)}{dt} = f(h(t), t, \\theta)\n",
    "\" />\n",
    "Начиная со входного слоя <img src=\"https://tex.s2cms.ru/svg/h(0)\" alt=\"h(0)\" />, мы можем определить выходной слой <img src=\"https://tex.s2cms.ru/svg/h(T)\" alt=\"h(T)\" /> как решение этого ОДУ в момент времени T.\n",
    "\n",
    "Теперь мы можем считать <img src=\"https://tex.s2cms.ru/svg/%5Ctheta\" alt=\"\\theta\" /> как распределенные (<em>shared</em>) параметры между всеми бесконечно малыми блоками.\n",
    "\n",
    "<h3>Проверка Neural ODE архитектуры на MNIST</h3>\n",
    "В этой части мы проверим возможность <em>Neural ODE</em> быть использованными в виде компонентов в более привычных архитектурах.\n",
    "\n",
    "В частности, мы заменим остаточные (<em>residual</em>) блоки на <em>Neural ODE</em> в классификаторе MNIST.\n",
    "<img src=\"https://habrastorage.org/webt/7c/gw/ia/7cgwiawqx6-_kxk82qpenqplowi.png\" width=400 />\n",
    "Код ниже (спойлер)\n",
    "\n",
    "<source lang=\"python\">def norm(dim):\n",
    "    return nn.BatchNorm2d(dim)\n",
    "\n",
    "def conv3x3(in_feats, out_feats, stride=1):\n",
    "    return nn.Conv2d(in_feats, out_feats, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "\n",
    "def add_time(in_tensor, t):\n",
    "    bs, c, w, h = in_tensor.shape\n",
    "    return torch.cat((in_tensor, t.expand(bs, 1, w, h)), dim=1)\n",
    "\n",
    "class ConvODEF(ODEF):\n",
    "    def __init__(self, dim):\n",
    "        super(ConvODEF, self).__init__()\n",
    "        self.conv1 = conv3x3(dim + 1, dim)\n",
    "        self.norm1 = norm(dim)\n",
    "        self.conv2 = conv3x3(dim + 1, dim)\n",
    "        self.norm2 = norm(dim)\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        xt = add_time(x, t)\n",
    "        h = self.norm1(torch.relu(self.conv1(xt)))\n",
    "        ht = add_time(h, t)\n",
    "        dxdt = self.norm2(torch.relu(self.conv2(ht)))\n",
    "        return dxdt\n",
    "\n",
    "class ContinuousNeuralMNISTClassifier(nn.Module):\n",
    "    def __init__(self, ode):\n",
    "        super(ContinuousNeuralMNISTClassifier, self).__init__()\n",
    "        self.downsampling = nn.Sequential(\n",
    "            nn.Conv2d(1, 64, 3, 1),\n",
    "            norm(64),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(64, 64, 4, 2, 1),\n",
    "            norm(64),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(64, 64, 4, 2, 1),\n",
    "        )\n",
    "        self.feature = ode\n",
    "        self.norm = norm(64)\n",
    "        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "        self.fc = nn.Linear(64, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.downsampling(x)\n",
    "        x = self.feature(x)\n",
    "        x = self.norm(x)\n",
    "        x = self.avg_pool(x)\n",
    "        shape = torch.prod(torch.tensor(x.shape[1:])).item()\n",
    "        x = x.view(-1, shape)\n",
    "        out = self.fc(x)\n",
    "        return out\n",
    "\n",
    "func = ConvODEF(64)\n",
    "ode = NeuralODE(func)\n",
    "model = ContinuousNeuralMNISTClassifier(ode)\n",
    "if use_cuda:\n",
    "    model = model.cuda()\n",
    "\n",
    "import torchvision\n",
    "\n",
    "img_std = 0.3081\n",
    "img_mean = 0.1307\n",
    "\n",
    "batch_size = 32\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    torchvision.datasets.MNIST(\"data/mnist\", train=True, download=True,\n",
    "                             transform=torchvision.transforms.Compose([\n",
    "                                 torchvision.transforms.ToTensor(),\n",
    "                                 torchvision.transforms.Normalize((img_mean,), (img_std,))\n",
    "                             ])\n",
    "    ),\n",
    "    batch_size=batch_size, shuffle=True\n",
    ")\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    torchvision.datasets.MNIST(\"data/mnist\", train=False, download=True,\n",
    "                             transform=torchvision.transforms.Compose([\n",
    "                                 torchvision.transforms.ToTensor(),\n",
    "                                 torchvision.transforms.Normalize((img_mean,), (img_std,))\n",
    "                             ])\n",
    "    ),\n",
    "    batch_size=128, shuffle=True\n",
    ")\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "\n",
    "def train(epoch):\n",
    "    num_items = 0\n",
    "    train_losses = []\n",
    "\n",
    "    model.train()\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    print(f\"Training Epoch {epoch}...\")\n",
    "    for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):\n",
    "        if use_cuda:\n",
    "            data = data.cuda()\n",
    "            target = target.cuda()\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = criterion(output, target) \n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_losses += [loss.item()]\n",
    "        num_items += data.shape[0]\n",
    "    print('Train loss: {:.5f}'.format(np.mean(train_losses)))\n",
    "    return train_losses\n",
    "\n",
    "def test():\n",
    "    accuracy = 0.0\n",
    "    num_items = 0\n",
    "\n",
    "    model.eval()\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    print(f\"Testing...\")\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (data, target) in tqdm(enumerate(test_loader),  total=len(test_loader)):\n",
    "            if use_cuda:\n",
    "                data = data.cuda()\n",
    "                target = target.cuda()\n",
    "            output = model(data)\n",
    "            accuracy += torch.sum(torch.argmax(output, dim=1) == target).item()\n",
    "            num_items += data.shape[0]\n",
    "    accuracy = accuracy * 100 / num_items\n",
    "    print(\"Test Accuracy: {:.3f}%\".format(accuracy))\n",
    "\n",
    "n_epochs = 5\n",
    "test()\n",
    "train_losses = []\n",
    "for epoch in range(1, n_epochs + 1):\n",
    "    train_losses += train(epoch)\n",
    "    test()\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "plt.figure(figsize=(9, 5))\n",
    "history = pd.DataFrame({\"loss\": train_losses})\n",
    "history[\"cum_data\"] = history.index * batch_size\n",
    "history[\"smooth_loss\"] = history.loss.ewm(halflife=10).mean()\n",
    "history.plot(x=\"cum_data\", y=\"smooth_loss\", figsize=(12, 5), title=\"train error\")\n",
    "</source>\n",
    "\n",
    "<source>Testing...\n",
    "100% 79/79 [00:01<;00:00, 45.69it/s]\n",
    "Test Accuracy: 9.740%\n",
    "\n",
    "Training Epoch 1...\n",
    "100% 1875/1875 [01:15<;00:00, 24.69it/s]\n",
    "Train loss: 0.20137\n",
    "Testing...\n",
    "100% 79/79 [00:01<;00:00, 46.64it/s]\n",
    "Test Accuracy: 98.680%\n",
    "\n",
    "Training Epoch 2...\n",
    "100% 1875/1875 [01:17<;00:00, 24.32it/s]\n",
    "Train loss: 0.05059\n",
    "Testing...\n",
    "100% 79/79 [00:01<;00:00, 46.11it/s]\n",
    "Test Accuracy: 97.760%\n",
    "\n",
    "Training Epoch 3...\n",
    "100% 1875/1875 [01:16<;00:00, 24.63it/s]\n",
    "Train loss: 0.03808\n",
    "Testing...\n",
    "100% 79/79 [00:01<;00:00, 45.65it/s]\n",
    "Test Accuracy: 99.000%\n",
    "\n",
    "Training Epoch 4...\n",
    "100% 1875/1875 [01:17<;00:00, 24.28it/s]\n",
    "Train loss: 0.02894\n",
    "Testing...\n",
    "100% 79/79 [00:01<;00:00, 45.42it/s]\n",
    "Test Accuracy: 99.130%\n",
    "\n",
    "Training Epoch 5...\n",
    "100% 1875/1875 [01:16<;00:00, 24.67it/s]\n",
    "Train loss: 0.02424\n",
    "Testing...\n",
    "100% 79/79 [00:01<;00:00, 45.89it/s]\n",
    "Test Accuracy: 99.170%\n",
    "</source>\n",
    "<img src=\"assets/train_error.png\" alt=\"train error\">\n",
    "\n",
    "После очень грубой тренировки в течение всего 5 эпох и 6 минут обучения, модель уже достигла тестовой ошибки в менее, чем 1%. Можно сказать, что <em>Нейронные ОДУ</em> хорошо интегрируются <em>в виде</em> компонента в более традиционные сети.\n",
    "\n",
    "В своей статье авторы также сравнивают этот классификатор (ODE-Net) с обычной полнозвязной сетью, с ResNet’ом с похожей архитектурой, и с точно такой же архитектурой, только в которой градиент распространяется напрямую через операции в <em>ODESolve</em> (без метода сопряженного градиента) (RK-Net).\n",
    "<img src=\"https://habrastorage.org/webt/zx/zv/5x/zxzv5x4ktqyy9lt19of-d2i4few.png\" />\n",
    "<div align=\"center\">Иллюстрация из оригинальной статьи</div>\n",
    "Согласно им, 1-слойная полносвязноая сеть с примерно таким же количеством параметров как <em>Neural ODE</em> имеет намного более высокую ошибку на тесте, ResNet с примерно такой же ошибкой имеет намного больше параметров, а RK-Net без метода сопряженного градиента, имеет чуть более высокую ошибку и с линейно растущим потреблением памяти (чем меньше допустимая ошибка, тем больше шагов должен сделать <em>ODESolve</em>, что линейно увеличивает потребляемую память с числом шагов).\n",
    "\n",
    "Авторы в своей имплементации используют неявный метод Рунге-Кутты с адаптивным размером шага, в отличие от простейшего метода Эйлера здесь. Они также изучают некоторые свойства новой архитектуры.\n",
    "<img src=\"https://habrastorage.org/webt/u7/li/fq/u7lifqdsf72otgcyfigm6-fiyww.png\" />\n",
    "<div align=\"center\">Характеристика ODE-Net (NFE Forward - количество вычислений функции при прямом проходе)</div>\n",
    "<div align=\"center\">Иллюстрация из оригинальной статьи</div>\n",
    "<ul>\n",
    "<li>(a) Изменение допустимого уровня численной ошибки изменяет количество шагов в прямом распространении.\n",
    "</li>\n",
    "<li>(b) Время потраченное на прямоу распространение пропорционально количеству вычеслений функции.\n",
    "</li>\n",
    "<li>© Количество вычислений функции при обратном распространение составляет примерно половину от прямого распространения, это указывает на то, что метод сопряженного градиента может быть более вычислительно эффективным, чем распространение градиента напрямую через <em>ODESolve</em>.\n",
    "</li>\n",
    "<li>(d) Как ODE-Net становится все более и более обученным, он требует все больше вычислений функции (все меньший шаг), возможно адаптируясь под восрастающую сложность модели.\n",
    "</li>\n",
    "</ul>\n",
    "<h2>Скрытая генеративная функция для моделирования временного ряда</h2>\n",
    "Neural ODE подходит для обработки непрерывных последовательных данных и тогда, когда траектория лежит в неизвестном скрытом пространстве.\n",
    "\n",
    "В этом разделе мы поэкспер<em>и</em>ментируем в генерации непрерывных последовательностей, используя <em>Neural ODE</em>, и немножко посмотрим на выученное скрытое пространство.\n",
    "\n",
    "Авторы также сравнивают это с аналогичными последовательност<em>ями</em>, сгенерированными Рекуррентными сетями.\n",
    "\n",
    "Эксперимент здесь слегка отличается от соответствующего примера в репозитории авторов, здесь более разнообразное множество траекторий.\n",
    "\n",
    "<h3>Данные</h3>\n",
    "Обучающие данные состоят из случайных спиралей, половина из которых направлены по-часовой, а вторая - против часовой. Далее случайные подпоследовательности сэмплируются из этих спиралей, обрабатываются кодирующей рекуррентной моделью в обратном направлении, порождая стартовое скрытое состояние, которое затем эволюционирует, создавая траекторию в скрытом пространстве. Это скрытая траектория затем отображается в пространство данных и сравнивается с сэмплированной подпоследовательностью. Таким образом, модель учится генерировать траектории, похожие на датасет.\n",
    "<img src=\"https://habrastorage.org/webt/or/4z/lg/or4zlgwkqjm-kzhvlmqtzmeqnl4.png\" />\n",
    "<div align=\"center\">Примеры спиралей из датасета</div>\n",
    "<h3>VAE как генеративная модель</h3>\n",
    "Генеративная модель через процедуру сэмплирования:\n",
    "<img src=\"https://tex.s2cms.ru/svg/%0Az_%7Bt_0%7D%20%5Csim%20%5Cmathcal%7BN%7D(0%2C%20I)%0A\" alt=\"\n",
    "z_{t_0} \\sim \\mathcal{N}(0, I)\n",
    "\" />\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0Az_%7Bt_1%7D%2C%20z_%7Bt_2%7D%2C...%2Cz_%7Bt_M%7D%20%3D%20%5Ctext%7BODESolve%7D(z_%7Bt_0%7D%2C%20f%2C%20%5Ctheta_f%2C%20t_0%2C...%2Ct_M)%0A\" alt=\"\n",
    "z_{t_1}, z_{t_2},...,z_{t_M} = \\text{ODESolve}(z_{t_0}, f, \\theta_f, t_0,...,t_M)\n",
    "\" />\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0Ax_%7Bt_i%7D%20%5Csim%20p(x%20%5Cmid%20z_%7Bt_i%7D%3B%5Ctheta_x)%0A\" alt=\"\n",
    "x_{t_i} \\sim p(x \\mid z_{t_i};\\theta_x)\n",
    "\" />\n",
    "Которая может быть обучена, используя подход вариационных автокодировщиков.\n",
    "<ol>\n",
    "<li>Пройтись рекуррентным энкодером через временную последовательность назад во времени, чтобы получить параметры <img src=\"https://tex.s2cms.ru/svg/%5Cmu_%7Bz_%7Bt_0%7D%7D\" alt=\"\\mu_{z_{t_0}}\" />, <img src=\"https://tex.s2cms.ru/svg/%5Csigma_%7Bz_%7Bt_0%7D%7D\" alt=\"\\sigma_{z_{t_0}}\" /> вариационного апестериорного распределения, а потом сэмплировать из него:\n",
    "</li>\n",
    "</ol>\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0Az_%7Bt_0%7D%20%5Csim%20q%20%5Cleft(%20z_%7Bt_0%7D%20%5Cmid%20x_%7Bt_0%7D%2C...%2Cx_%7Bt_M%7D%3B%20t_0%2C...%2Ct_M%3B%20%5Ctheta_q%20%5Cright)%20%3D%20%5Cmathcal%7BN%7D%20%5Cleft(z_%7Bt_0%7D%20%5Cmid%20%5Cmu_%7Bz_%7Bt_0%7D%7D%20%5Csigma_%7Bz_%7Bt_0%7D%7D%20%5Cright)%0A\" alt=\"\n",
    "z_{t_0} \\sim q \\left( z_{t_0} \\mid x_{t_0},...,x_{t_M}; t_0,...,t_M; \\theta_q \\right) = \\mathcal{N} \\left(z_{t_0} \\mid \\mu_{z_{t_0}} \\sigma_{z_{t_0}} \\right)\n",
    "\" />\n",
    "<ol start=\"2\">\n",
    "<li>Получить скрытую траекторию:\n",
    "</li>\n",
    "</ol>\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0Az_%7Bt_1%7D%2C%20z_%7Bt_2%7D%2C...%2Cz_%7Bt_N%7D%20%3D%20%5Ctext%7BODESolve%7D(z_%7Bt_0%7D%2C%20f%2C%20%5Ctheta_f%2C%20t_0%2C...%2Ct_N)%2C%20%5Ctext%7B%20%D0%B3%D0%B4%D0%B5%20%7D%20%5Cfrac%7Bd%20z%7D%7Bd%20t%7D%20%3D%20f(z%2C%20t%3B%20%5Ctheta_f)%0A\" alt=\"\n",
    "z_{t_1}, z_{t_2},...,z_{t_N} = \\text{ODESolve}(z_{t_0}, f, \\theta_f, t_0,...,t_N), \\text{ где } \\frac{d z}{d t} = f(z, t; \\theta_f)\n",
    "\" />\n",
    "<ol start=\"3\">\n",
    "<li>Отобразить скрытую траекторию в траекторию в данных, используя другую нейросеть: <img src=\"https://tex.s2cms.ru/svg/%5Chat%7Bx_%7Bt_i%7D%7D(z_%7Bt_i%7D%2C%20t_i%3B%20%5Ctheta_x)\" alt=\"\\hat{x_{t_i}}(z_{t_i}, t_i; \\theta_x)\" />\n",
    "</li>\n",
    "<li>Максимизировать оценку нижней границы обоснованности (ELBO) для сэмплированной траектории:\n",
    "</li>\n",
    "</ol>\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Ctext%7BELBO%7D%20%5Capprox%20N%20%5CBig(%20%5Csum_%7Bi%3D0%7D%5E%7BM%7D%20%5Clog%20p(x_%7Bt_i%7D%20%5Cmid%20z_%7Bt_i%7D(z_%7Bt_0%7D%3B%20%5Ctheta_f)%3B%20%5Ctheta_x)%20%2B%20KL%20%5Cleft(%20q(%20z_%7Bt_0%7D%20%5Cmid%20x_%7Bt_0%7D%2C...%2Cx_%7Bt_M%7D%3B%20t_0%2C...%2Ct_M%3B%20%5Ctheta_q)%20%5Cparallel%20%5Cmathcal%7BN%7D(0%2C%20I)%20%5Cright)%20%5CBig)%0A\" alt=\"\n",
    "\\text{ELBO} \\approx N \\Big( \\sum_{i=0}^{M} \\log p(x_{t_i} \\mid z_{t_i}(z_{t_0}; \\theta_f); \\theta_x) + KL \\left( q( z_{t_0} \\mid x_{t_0},...,x_{t_M}; t_0,...,t_M; \\theta_q) \\parallel \\mathcal{N}(0, I) \\right) \\Big)\n",
    "\" />\n",
    "И в случае Гауссовского апостериорного распределения <img src=\"https://tex.s2cms.ru/svg/p(x%20%5Cmid%20z_%7Bt_i%7D%3B%5Ctheta_x)\" alt=\"p(x \\mid z_{t_i};\\theta_x)\" /> и известного уровня шума <img src=\"https://tex.s2cms.ru/svg/%5Csigma_x\" alt=\"\\sigma_x\" />:\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Ctext%7BELBO%7D%20%5Capprox%20-N%20%5CBig(%20%5Csum_%7Bi%3D1%7D%5E%7BM%7D%5Cfrac%7B(x_i%20-%20%5Chat%7Bx_i%7D%20)%5E2%7D%7B%5Csigma_x%5E2%7D%20-%20%5Clog%20%5Csigma_%7Bz_%7Bt_0%7D%7D%5E2%20%2B%20%5Cmu_%7Bz_%7Bt_0%7D%7D%5E2%20%2B%20%5Csigma_%7Bz_%7Bt_0%7D%7D%5E2%20%5CBig)%20%2B%20C%0A\" alt=\"\n",
    "\\text{ELBO} \\approx -N \\Big( \\sum_{i=1}^{M}\\frac{(x_i - \\hat{x_i} )^2}{\\sigma_x^2} - \\log \\sigma_{z_{t_0}}^2 + \\mu_{z_{t_0}}^2 + \\sigma_{z_{t_0}}^2 \\Big) + C\n",
    "\" />\n",
    "Граф вычислений скрытой ОДУ модель (<em>модели?</em>) можно изобразить вот так\n",
    "<img src=\"https://habrastorage.org/webt/zl/57/ui/zl57uisgnvjy7-y94fkker9m1eq.png\" />\n",
    "<div align=\"center\">Иллюстрация из оригинальной статьи</div>\n",
    "Эту модель можно затем протестировать на то, как она интерполирует траекторию, используя только начальные наблюдения.\n",
    "\n",
    "Код ниже (спойлер)\n",
    "\n",
    "<h3>Define models</h3>\n",
    "\n",
    "<source lang=\"python\">class RNNEncoder(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, latent_dim):\n",
    "        super(RNNEncoder, self).__init__()\n",
    "        self.input_dim = input_dim\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "        self.rnn = nn.GRU(input_dim+1, hidden_dim)\n",
    "        self.hid2lat = nn.Linear(hidden_dim, 2*latent_dim)\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        # Concatenate time to input\n",
    "        t = t.clone()\n",
    "        t[1:] = t[:-1] - t[1:]\n",
    "        t[0] = 0.\n",
    "        xt = torch.cat((x, t), dim=-1)\n",
    "\n",
    "        _, h0 = self.rnn(xt.flip((0,)))  # Reversed\n",
    "        # Compute latent dimension\n",
    "        z0 = self.hid2lat(h0[0])\n",
    "        z0_mean = z0[:, :self.latent_dim]\n",
    "        z0_log_var = z0[:, self.latent_dim:]\n",
    "        return z0_mean, z0_log_var\n",
    "\n",
    "class NeuralODEDecoder(nn.Module):\n",
    "    def __init__(self, output_dim, hidden_dim, latent_dim):\n",
    "        super(NeuralODEDecoder, self).__init__()\n",
    "        self.output_dim = output_dim\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "        func = NNODEF(latent_dim, hidden_dim, time_invariant=True)\n",
    "        self.ode = NeuralODE(func)\n",
    "        self.l2h = nn.Linear(latent_dim, hidden_dim)\n",
    "        self.h2o = nn.Linear(hidden_dim, output_dim)\n",
    "\n",
    "    def forward(self, z0, t):\n",
    "        zs = self.ode(z0, t, return_whole_sequence=True)\n",
    "\n",
    "        hs = self.l2h(zs)\n",
    "        xs = self.h2o(hs)\n",
    "        return xs\n",
    "\n",
    "class ODEVAE(nn.Module):\n",
    "    def __init__(self, output_dim, hidden_dim, latent_dim):\n",
    "        super(ODEVAE, self).__init__()\n",
    "        self.output_dim = output_dim\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "        self.encoder = RNNEncoder(output_dim, hidden_dim, latent_dim)\n",
    "        self.decoder = NeuralODEDecoder(output_dim, hidden_dim, latent_dim)\n",
    "\n",
    "    def forward(self, x, t, MAP=False):\n",
    "        z_mean, z_log_var = self.encoder(x, t)\n",
    "        if MAP:\n",
    "            z = z_mean\n",
    "        else:\n",
    "            z = z_mean + torch.randn_like(z_mean) * torch.exp(0.5 * z_log_var)\n",
    "        x_p = self.decoder(z, t)\n",
    "        return x_p, z, z_mean, z_log_var\n",
    "\n",
    "    def generate_with_seed(self, seed_x, t):\n",
    "        seed_t_len = seed_x.shape[0]\n",
    "        z_mean, z_log_var = self.encoder(seed_x, t[:seed_t_len])\n",
    "        x_p = self.decoder(z_mean, t)\n",
    "        return x_p\n",
    "</source>\n",
    "<h3>Генерация датасета</h3>\n",
    "\n",
    "<source lang=\"python\">t_max = 6.29*5\n",
    "n_points = 200\n",
    "noise_std = 0.02\n",
    "\n",
    "num_spirals = 1000\n",
    "\n",
    "index_np = np.arange(0, n_points, 1, dtype=np.int)\n",
    "index_np = np.hstack([index_np[:, None]])\n",
    "times_np = np.linspace(0, t_max, num=n_points)\n",
    "times_np = np.hstack([times_np[:, None]] * num_spirals)\n",
    "times = torch.from_numpy(times_np[:, :, None]).to(torch.float32)\n",
    "\n",
    "# Generate random spirals parameters\n",
    "normal01 = torch.distributions.Normal(0, 1.0)\n",
    "\n",
    "x0 = Variable(normal01.sample((num_spirals, 2))) * 2.0  \n",
    "\n",
    "W11 = -0.1 * normal01.sample((num_spirals,)).abs() - 0.05\n",
    "W22 = -0.1 * normal01.sample((num_spirals,)).abs() - 0.05\n",
    "W21 = -1.0 * normal01.sample((num_spirals,)).abs()\n",
    "W12 =  1.0 * normal01.sample((num_spirals,)).abs()\n",
    "\n",
    "xs_list = []\n",
    "for i in range(num_spirals):\n",
    "    if i % 2 == 1: #  Make it counter-clockwise\n",
    "        W21, W12 = W12, W21\n",
    "\n",
    "    func = LinearODEF(Tensor([[W11[i], W12[i]], [W21[i], W22[i]]]))\n",
    "    ode = NeuralODE(func)\n",
    "\n",
    "    xs = ode(x0[i:i+1], times[:, i:i+1], return_whole_sequence=True)\n",
    "    xs_list.append(xs)\n",
    "\n",
    "\n",
    "orig_trajs = torch.cat(xs_list, dim=1).detach()\n",
    "samp_trajs = orig_trajs + torch.randn_like(orig_trajs) * noise_std\n",
    "samp_ts = times\n",
    "\n",
    "fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(15, 9))\n",
    "axes = axes.flatten()\n",
    "for i, ax in enumerate(axes):\n",
    "    ax.scatter(samp_trajs[:, i, 0], samp_trajs[:, i, 1], c=samp_ts[:, i, 0], cmap=cm.plasma)\n",
    "plt.show()\n",
    "\n",
    "import numpy.random as npr\n",
    "\n",
    "def gen_batch(batch_size, n_sample=100):\n",
    "    n_batches = samp_trajs.shape[1] // batch_size\n",
    "    time_len = samp_trajs.shape[0]\n",
    "    n_sample = min(n_sample, time_len)\n",
    "    for i in range(n_batches):\n",
    "        if n_sample >; 0:\n",
    "            t0_idx = npr.multinomial(1, [1. / (time_len - n_sample)] * (time_len - n_sample))\n",
    "            t0_idx = np.argmax(t0_idx)\n",
    "            tM_idx = t0_idx + n_sample\n",
    "        else:\n",
    "            t0_idx = 0\n",
    "            tM_idx = time_len\n",
    "\n",
    "        frm, to = batch_size*i, batch_size*(i+1)\n",
    "        yield samp_trajs[t0_idx:tM_idx, frm:to], samp_ts[t0_idx:tM_idx, frm:to]\n",
    "</source>\n",
    "<h3>Обучение</h3>\n",
    "\n",
    "<source lang=\"python\">vae = ODEVAE(2, 64, 6)\n",
    "vae = vae.cuda()\n",
    "if use_cuda:\n",
    "    vae = vae.cuda()\n",
    "\n",
    "optim = torch.optim.Adam(vae.parameters(), betas=(0.9, 0.999), lr=0.001)\n",
    "\n",
    "preload = False\n",
    "n_epochs = 20000\n",
    "batch_size = 100\n",
    "\n",
    "plot_traj_idx = 1\n",
    "plot_traj = orig_trajs[:, plot_traj_idx:plot_traj_idx+1]\n",
    "plot_obs = samp_trajs[:, plot_traj_idx:plot_traj_idx+1]\n",
    "plot_ts = samp_ts[:, plot_traj_idx:plot_traj_idx+1]\n",
    "if use_cuda:\n",
    "    plot_traj = plot_traj.cuda()\n",
    "    plot_obs = plot_obs.cuda()\n",
    "    plot_ts = plot_ts.cuda()\n",
    "\n",
    "if preload:\n",
    "    vae.load_state_dict(torch.load(\"models/vae_spirals.sd\"))\n",
    "\n",
    "for epoch_idx in range(n_epochs):\n",
    "    losses = []\n",
    "    train_iter = gen_batch(batch_size)\n",
    "    for x, t in train_iter:\n",
    "        optim.zero_grad()\n",
    "        if use_cuda:\n",
    "            x, t = x.cuda(), t.cuda()\n",
    "\n",
    "        max_len = np.random.choice([30, 50, 100])\n",
    "        permutation = np.random.permutation(t.shape[0])\n",
    "        np.random.shuffle(permutation)\n",
    "        permutation = np.sort(permutation[:max_len])\n",
    "\n",
    "        x, t = x[permutation], t[permutation]\n",
    "\n",
    "        x_p, z, z_mean, z_log_var = vae(x, t)\n",
    "        kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean**2 - torch.exp(z_log_var), -1)\n",
    "        loss = 0.5 * ((x-x_p)**2).sum(-1).sum(0) / noise_std**2 + kl_loss\n",
    "        loss = torch.mean(loss)\n",
    "        loss /= max_len\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        losses.append(loss.item())\n",
    "\n",
    "    print(f\"Epoch {epoch_idx}\")\n",
    "\n",
    "    frm, to, to_seed = 0, 200, 50\n",
    "    seed_trajs = samp_trajs[frm:to_seed]\n",
    "    ts = samp_ts[frm:to]\n",
    "    if use_cuda:\n",
    "        seed_trajs = seed_trajs.cuda()\n",
    "        ts = ts.cuda()\n",
    "\n",
    "    samp_trajs_p = to_np(vae.generate_with_seed(seed_trajs, ts))\n",
    "\n",
    "    fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(15, 9))\n",
    "    axes = axes.flatten()\n",
    "    for i, ax in enumerate(axes):\n",
    "        ax.scatter(to_np(seed_trajs[:, i, 0]), to_np(seed_trajs[:, i, 1]), c=to_np(ts[frm:to_seed, i, 0]), cmap=cm.plasma)\n",
    "        ax.plot(to_np(orig_trajs[frm:to, i, 0]), to_np(orig_trajs[frm:to, i, 1]))\n",
    "        ax.plot(samp_trajs_p[:, i, 0], samp_trajs_p[:, i, 1])\n",
    "    plt.show()\n",
    "\n",
    "    print(np.mean(losses), np.median(losses))\n",
    "    clear_output(wait=True)\n",
    "\n",
    "spiral_0_idx = 3\n",
    "spiral_1_idx = 6\n",
    "\n",
    "homotopy_p = Tensor(np.linspace(0., 1., 10)[:, None])\n",
    "vae = vae\n",
    "if use_cuda:\n",
    "    homotopy_p = homotopy_p.cuda()\n",
    "    vae = vae.cuda()\n",
    "\n",
    "spiral_0 = orig_trajs[:, spiral_0_idx:spiral_0_idx+1, :]\n",
    "spiral_1 = orig_trajs[:, spiral_1_idx:spiral_1_idx+1, :]\n",
    "ts_0 = samp_ts[:, spiral_0_idx:spiral_0_idx+1, :]\n",
    "ts_1 = samp_ts[:, spiral_1_idx:spiral_1_idx+1, :]\n",
    "if use_cuda:\n",
    "    spiral_0, ts_0 = spiral_0.cuda(), ts_0.cuda()\n",
    "    spiral_1, ts_1 = spiral_1.cuda(), ts_1.cuda()\n",
    "\n",
    "z_cw, _ = vae.encoder(spiral_0, ts_0)\n",
    "z_cc, _ = vae.encoder(spiral_1, ts_1)\n",
    "\n",
    "homotopy_z = z_cw * (1 - homotopy_p) + z_cc * homotopy_p\n",
    "\n",
    "t = torch.from_numpy(np.linspace(0, 6*np.pi, 200))\n",
    "t = t[:, None].expand(200, 10)[:, :, None].cuda()\n",
    "t = t.cuda() if use_cuda else t\n",
    "hom_gen_trajs = vae.decoder(homotopy_z, t)\n",
    "\n",
    "fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(15, 5))\n",
    "axes = axes.flatten()\n",
    "for i, ax in enumerate(axes):\n",
    "    ax.plot(to_np(hom_gen_trajs[:, i, 0]), to_np(hom_gen_trajs[:, i, 1]))\n",
    "plt.show()\n",
    "\n",
    "torch.save(vae.state_dict(), \"models/vae_spirals.sd\")\n",
    "</source>\n",
    "Вот что получается после ночи обучения\n",
    "<img src=\"https://habrastorage.org/webt/fv/pn/pi/fvpnpimixf3sq0cezhepgzh2txy.png\" />\n",
    "<div align=\"center\"> Точки - это зашумленные наблюдения оригинальной траектории (синий), <br /> желтая - это реконструированная и интерполированная траектория, используя точки как входы. <br /> Цвет точки показывает время. </div>\n",
    "Реконструкции некоторых примеров не выглядят слишком хорошими. Может модель недостаточно сложная или недостаточно долго училась. В любом случае реконструкции выглядят очень разумно.\n",
    "\n",
    "Теперь посмотрим что будет, если интерполировать скрытую переменную по-часовой траектории к противо-часовой траектории.\n",
    "<img src=\"https://habrastorage.org/webt/at/g_/7e/atg_7ecofins-uohnv99qccfxfq.png\" />\n",
    "\n",
    "Авторы также сравнивают реконструкции и интерполяции траекторий между <em>Neural ODE</em> и простой Рекуррентной сетью.\n",
    "<img src=\"https://habrastorage.org/webt/kj/ak/-e/kjak-evgr-7mrrtpgimf0gfmfxq.png\" />\n",
    "<div align=\"center\">Иллюстрация из оригинальной статьи</div>\n",
    "<h2>Непрерывные Нормализующие Потоки</h2>\n",
    "Оригинальная статья также привносит многое в тему Нормализующих Потоков. Нормализующие потоки используются, когда нужно сэмплировать из некоторого сложного распределения, появившегося через замену переменных от некоторого простого распределения (Гауссовского, например), и при этом все еще знать плотность вероятности в точке каждого сэмпла.\n",
    "Авторы показывают, что использование непрерывной замены переменных намного более вычислительно эффективно и интерпретируемо, чем предыдущие методы.\n",
    "\n",
    "<em>Нормализующие потоки</em> очень полезны в таких моделях как <em>Вариационные Автокодировщики</em>, <em>Байесовские Нейронные Сети</em> и других из Баейсовского подхода.\n",
    "\n",
    "Эта тема, впрочем, лежит за пределами <em>данной</em> статьи, и тем, кто заинтересовался, следует прочесть оригинальную научную статью.\n",
    "\n",
    "Для затравки:\n",
    "<img src=\"https://habrastorage.org/webt/nd/e_/wn/nde_wn590scz9dff_0hzxo1-tgg.png\" />\n",
    "<div align=\"center\">Визуализация трансформации из шума (простого распределения) в данные (сложное распределение) для двух датасетов; <br /> Ось-X показывает трансформацию плотности и сэмплов с течением \"времени\" (для ННП) и \"глубины\" (для НП). <br />Иллюстрация из оригинальной статьи</div>\n",
    "Спасибо @bekemax за помощь в правке английской версии текста и за интересные физические комментарии.\n",
    "\n",
    "Это завершает мое небольшое исследование <strong>Neural ODEs</strong>. Спасибо за внимание, надеюсь вам понравилось!\n",
    "\n",
    "<h1>Полезные ссылки</h1>\n",
    "<ul>\n",
    "<li><a href=\"https://arxiv.org/abs/1806.07366\">Оригинальная статья</a>\n",
    "</li>\n",
    "<li><a href=\"https://github.com/rtqichen/torchdiffeq\">Авторский репозиторий</a>\n",
    "</li>\n",
    "<li><a href=\"https://www.cs.princeton.edu/courses/archive/fall11/cos597C/lectures/variational-inference-i.pdf\">Вариационный вывод</a>\n",
    "</li>\n",
    "<li><a href=\"https://habr.com/en/post/331552/\">Моя статья про VAE (Русский)</a>\n",
    "</li>\n",
    "<li><a href=\"https://www.jeremyjordan.me/variational-autoencoders/\">Объяснение VAE</a>\n",
    "</li>\n",
    "<li><a href=\"http://akosiorek.github.io/ml/2018/04/03/norm_flows.html\">Больше про нормализующие потоки</a>\n",
    "</li>\n",
    "<li><a href=\"https://arxiv.org/abs/1505.05770\">Variational Inference with Normalizing Flows Paper</a>\n",
    "</li>\n",
    "</ul>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "<img align=\"center\" src=\"//tex.s2cms.ru/svg/%0A%5Cfrac%7Bdz%7D%7Bdt%7D%20%3D%20f(z(t)%2C%20t)%20%5C%3B%20(1)%0A\" alt=\"\n",
    "\\frac{dz}{dt} = f(z(t), t) \\; (1)\n",
    "\" />\n",
    "\n",
    "<img align=\"center\" src=\"https://tex.s2cms.ru/svg/%0A%5Cfrac%7Bdz%7D%7Bdt%7D%20%3D%20f(z(t)%2C%20t)%20%5C%3B%20(1)%0A\" alt=\"\n",
    "\\frac{dz}{dt} = f(z(t), t) \\; (1)\n",
    "\" />"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "01447e49744f4e15a918eab38023ee12": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "048637d9cc1145f09927902862f40951": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "05e09cca05514276a8729c3d97e57437": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "065445947c66425282ee8aaf7954a120": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "0775b035e93647aa9f244cae1ff06eab": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "07af261c59f6431fb4793235e4b161b2": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "0b6e4731c92f432689b4cc7c385392d4": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_46b9cd6a486b4ada8e8e39209763eabc",
       "style": "IPY_MODEL_7ba08023e68346c99d295c8cf317c9c6",
       "value": "  0% 0/79 [00:00&lt;?, ?it/s]"
      }
     },
     "0b85dee07d4749b18aef1dabf03bb1f4": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "0caf938964ce48848a446461157dfeb2": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_eb836fe776da47d8b89c0b01ecb89d98",
        "IPY_MODEL_7722c8d516c14ea7941084c5844c6162"
       ],
       "layout": "IPY_MODEL_7007118bd6734014b09e7e49477c0c4a"
      }
     },
     "0df95b684157426880512578f9f78d1f": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "0e7868acdb7c4c619625d37deb9eb3a3": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "0ea17b3b30934876a6770551876e249a": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "0f5ee0f14be24040a03c395b2231c00e": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "10784fae38df4b7d89198aace1bd64a7": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_f8a3eabaf27c46ca80f37df959abdbb4",
        "IPY_MODEL_279a240511db48b7bf70ab42356266b0"
       ],
       "layout": "IPY_MODEL_7831b7d3b6c54b538bea298f096fb536"
      }
     },
     "123379c3ee7743cba4ed7ec34ae86803": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "13fbea0ae8394bc586befbd1c8b4ae7b": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_96a3541c181347b39e533bcedf25f0d4",
        "IPY_MODEL_47f6d7e9386e43809947d74f5fbead50"
       ],
       "layout": "IPY_MODEL_65030196eaca4611bf4bf974a38610d6"
      }
     },
     "15905bc3a2bf4b29aebfe0fbecbefb11": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_475f88c0c44e4c4ca73a6494f280bcd9",
        "IPY_MODEL_5449f0ee3b8e487f906ee3677e58fb9b"
       ],
       "layout": "IPY_MODEL_7bede11c51444425aabed22076a37eb1"
      }
     },
     "16f0c997619d429785f07a3b41574900": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_90af5c7af9cc46e4b2e1dde7a04ce01a",
       "style": "IPY_MODEL_94a492c2964f4a7baf815dad7841b6db",
       "value": "100% 79/79 [00:01&lt;00:00, 45.89it/s]"
      }
     },
     "16f7d8c07e5844c4bed41ab1eabfa8e1": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "173a8399af0a4f20b483e0985bf17480": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_459b7bea4d514b53b67e970855d95de7",
        "IPY_MODEL_16f0c997619d429785f07a3b41574900"
       ],
       "layout": "IPY_MODEL_123379c3ee7743cba4ed7ec34ae86803"
      }
     },
     "1754cba9ef2d4f00b605b2852eeef3f7": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_30c723e20df74bce9aec7c29e280788d",
        "IPY_MODEL_54f60614bb0246d8b157793b021806e2"
       ],
       "layout": "IPY_MODEL_2ad0b5aab3c14fd9b193f5fe48ba8d9e"
      }
     },
     "1806bd82afb64886a188425bf61de3ee": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "1853c04636984ec6a178826a85930c1b": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "18983393e5414f6f858e8c034d9fb9bc": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "19c7b3b12ec3416299db573bd6984df5": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "1a9b8f506fca4149a2b2cd992b670785": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "1b92e09380d34936a89235d115570832": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "1d8be434b18e433b98f4321b257e3274": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "1e707ca751194ec2a7a14564648bd27a": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "1e92b05840124fe7bf8d9bd143da48cd": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_3d18d316f3e44e479038dfc5c8bc306b",
       "style": "IPY_MODEL_dcf78f860cc4483d9741b04aa3d6d1db",
       "value": "100% 79/79 [00:01&lt;00:00, 46.64it/s]"
      }
     },
     "1f89084ff986444bbd62198e126bcde4": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "1ffd204fab354f9f921d259346a47742": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "2220a84fa7b3415e8e5c775585eb0323": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "23b650880517487192d3a539ce2eee33": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_7b73e5216e7840faa3eb2662a0c01b2a",
       "max": 1875,
       "style": "IPY_MODEL_0775b035e93647aa9f244cae1ff06eab",
       "value": 1875
      }
     },
     "2524fdd6758a4d6db92eb448c6f025f8": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "279a240511db48b7bf70ab42356266b0": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_86c82f8a014e4048879c0d3298efdda2",
       "style": "IPY_MODEL_8eb957262d244a90a6f5d70e3c1195d9",
       "value": "100% 79/79 [00:01&lt;00:00, 46.80it/s]"
      }
     },
     "290d0742411a4702ae281fea70f3bcaa": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_fe681ff909f441eea634e7bb6169422b",
       "style": "IPY_MODEL_c27ea7e05da146a2ae2b4713ebcd38c0",
       "value": "100% 79/79 [00:01&lt;00:00, 44.72it/s]"
      }
     },
     "2a521e703d1a4df598d4d7ebc8917df1": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "2ad0b5aab3c14fd9b193f5fe48ba8d9e": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "2d5c4ccf97f641f4b36c516e49b42f1c": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_bc63d0e73204431c8d36c5cc5851b74c",
       "style": "IPY_MODEL_16f7d8c07e5844c4bed41ab1eabfa8e1",
       "value": "100% 1875/1875 [01:17&lt;00:00, 24.32it/s]"
      }
     },
     "3022b9630fe343639f36b7d80ef877eb": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_1ffd204fab354f9f921d259346a47742",
       "max": 79,
       "style": "IPY_MODEL_7ae4bac9876546f0b92d11366637d861",
       "value": 79
      }
     },
     "30c723e20df74bce9aec7c29e280788d": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_e6b679d3e9d544cfb613164b03c8991d",
       "max": 1875,
       "style": "IPY_MODEL_dc1ee5808f0946eaa79fd898647e0ba2",
       "value": 1875
      }
     },
     "31beb781955e4abea1fa33138a70ca69": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "32b30cb9acc244a9914e34f45879c21e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_4287d4cb0bb1408797be58903995174a",
       "style": "IPY_MODEL_5a734a68c8c04c6b96a31d7d4aa90f5f",
       "value": "100% 1875/1875 [01:15&lt;00:00, 24.69it/s]"
      }
     },
     "33497dd7f39b42da8ce8c28ae0eb1ebf": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_4b4e2cc965cc4c4088ea23c0b05484b9",
       "max": 1875,
       "style": "IPY_MODEL_6449f6d43b3e458bb419bcc1198fd5b9",
       "value": 1875
      }
     },
     "34a8733ed250438fb584cfaa46c00ab3": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_656337147cc34e5494d9c722d65496ba",
        "IPY_MODEL_fffa785bfcd54cb49a1c2a420bac2b51"
       ],
       "layout": "IPY_MODEL_9bb7628377ca49989eeeddf2c4d5f4c8"
      }
     },
     "34b6a9403d5b42e695cadb3a3446004a": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "3538802d34dc4533b03785643f57235d": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "36e2fc573d5d45c4bd9ed063f5e14954": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_87334fc2d276402e917d0517382e8b42",
       "max": 79,
       "style": "IPY_MODEL_4351044a73984ba19c3c1ba8fa67a921",
       "value": 79
      }
     },
     "3727ef4991cb4f7f9ef7d007df01f449": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "3bdc65622ceb44fa895eeccf2ecc0152": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "3c18e69d054d4a259f9b75cb5f2f57f5": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "3d18d316f3e44e479038dfc5c8bc306b": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "3dcea128ee1c47fabd9c1ca1513cbe0e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_3727ef4991cb4f7f9ef7d007df01f449",
       "style": "IPY_MODEL_d3f17e88798a4f6a9e33e9005df7c7f9",
       "value": "100% 79/79 [00:01&lt;00:00, 45.42it/s]"
      }
     },
     "3e54bef0104e4979b134823ac90c4fa6": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "3f6fa3cbad674dcb93ae3b17bcfb60d6": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_481f434e06ba4366b43eece1c7200a33",
       "max": 1875,
       "style": "IPY_MODEL_ba07dd1810c74b6ba518f6221168913b",
       "value": 1875
      }
     },
     "40f5b79337564201be954dda50b6ea67": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "415f6e33dce140f3ac3935aa6769bce9": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "4287d4cb0bb1408797be58903995174a": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "4351044a73984ba19c3c1ba8fa67a921": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "438a27adf49445deaa10c31609d44516": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_69f3ab8664b44198aad0fe0ecc18f08c",
        "IPY_MODEL_7ddab42457f440d59ede14a2b662c324"
       ],
       "layout": "IPY_MODEL_0f5ee0f14be24040a03c395b2231c00e"
      }
     },
     "44766dc4803a4a5aa6a8c752d0ae2049": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_2220a84fa7b3415e8e5c775585eb0323",
       "max": 79,
       "style": "IPY_MODEL_2524fdd6758a4d6db92eb448c6f025f8",
       "value": 79
      }
     },
     "44f2ee887e874370bc4b4d0330c843af": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "459b7bea4d514b53b67e970855d95de7": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_f209fbc7469f40e7963d4c18bae6e111",
       "max": 79,
       "style": "IPY_MODEL_1d8be434b18e433b98f4321b257e3274",
       "value": 79
      }
     },
     "46b9cd6a486b4ada8e8e39209763eabc": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "475f88c0c44e4c4ca73a6494f280bcd9": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_e199254644b34a6395dcc41cee2bd738",
       "max": 1875,
       "style": "IPY_MODEL_0b85dee07d4749b18aef1dabf03bb1f4",
       "value": 1875
      }
     },
     "47f6d7e9386e43809947d74f5fbead50": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_fffd484e385e4249bb9759808ee02fac",
       "style": "IPY_MODEL_9f4929ea841e4ab3960b47003ad1d648",
       "value": "100% 79/79 [00:01&lt;00:00, 43.20it/s]"
      }
     },
     "481f434e06ba4366b43eece1c7200a33": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "4a6e93fc0e6c47ff830e00836304d3f4": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_b0eb50e2e02740d4a166cb4ced86c574",
        "IPY_MODEL_b4e3c1c1e8f64d7d8158e341efda2edb"
       ],
       "layout": "IPY_MODEL_6766a5a8c8f14f3daf0eb3013746dc9f"
      }
     },
     "4a9b90eef6b94bcc9e7919fbba4419fb": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_728662f43b604263a57237acc830f2fa",
       "max": 79,
       "style": "IPY_MODEL_8a28e7cf467f422bafcc4a4f7f80d99d",
       "value": 79
      }
     },
     "4b4e2cc965cc4c4088ea23c0b05484b9": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "4bbe6a4019384880accf89bd6ea4ce43": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "4c1560d1809d4a1ea741d43dcef95076": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "529def52c63f4a29b578b1ece833bd93": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "5300061d70b04209a9b59fdd6493424c": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "5449f0ee3b8e487f906ee3677e58fb9b": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_40f5b79337564201be954dda50b6ea67",
       "style": "IPY_MODEL_6ddadc736a5845f198d128a1c0c0e94d",
       "value": "100% 1875/1875 [01:16&lt;00:00, 24.63it/s]"
      }
     },
     "54f60614bb0246d8b157793b021806e2": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_803ab43df57c4b49a0c9400d794ba6c3",
       "style": "IPY_MODEL_1f89084ff986444bbd62198e126bcde4",
       "value": "100% 1875/1875 [01:16&lt;00:00, 24.40it/s]"
      }
     },
     "54fef2fc76454dbc87f85d678e570fca": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "57541d738454477198d5f0ccc4521703": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_048637d9cc1145f09927902862f40951",
       "style": "IPY_MODEL_6ee900a5b4a44c46afefd5dbc277089a",
       "value": "100% 1875/1875 [01:17&lt;00:00, 24.28it/s]"
      }
     },
     "5a6b2b387a7141bfbe7d6826f12f54eb": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "5a734a68c8c04c6b96a31d7d4aa90f5f": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "5a9947e7eab6439b880b24ee85c27c61": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "6027f1f6d186491e9b1c6a0b4488f602": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "6262c5576e884ab7886d6e186df54ae8": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "63395bfc4f0e433588c3562efc322ddb": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_743078593d73460abf57d9429fa92c8b",
       "style": "IPY_MODEL_9568817c86d54a68b638bb217be5696d",
       "value": "100% 79/79 [00:01&lt;00:00, 45.65it/s]"
      }
     },
     "63cd0b359ca748c4a0aae0e09a4fbb3c": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_da9d36f1abb3470f89dcdd21931dcfcd",
        "IPY_MODEL_57541d738454477198d5f0ccc4521703"
       ],
       "layout": "IPY_MODEL_07af261c59f6431fb4793235e4b161b2"
      }
     },
     "6449f6d43b3e458bb419bcc1198fd5b9": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "65030196eaca4611bf4bf974a38610d6": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "656337147cc34e5494d9c722d65496ba": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_1a9b8f506fca4149a2b2cd992b670785",
       "max": 79,
       "style": "IPY_MODEL_9b94b49d7ef44cba8d0b514159407f8d",
       "value": 79
      }
     },
     "675221f89e2140ef8d60204390e9acda": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "6766a5a8c8f14f3daf0eb3013746dc9f": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "69523589bb43433cb1de06bb116e8958": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "69a43112476f4730b7c8d60e2b301e10": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "69f3ab8664b44198aad0fe0ecc18f08c": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_1e707ca751194ec2a7a14564648bd27a",
       "max": 79,
       "style": "IPY_MODEL_c292aa2c40e347418d62e2d8befd013d",
       "value": 79
      }
     },
     "6cfd09d8ff154c2786daa740299f3cf2": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "6ddadc736a5845f198d128a1c0c0e94d": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "6df4ed67c40043e59585f8160fb32156": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_36e2fc573d5d45c4bd9ed063f5e14954",
        "IPY_MODEL_63395bfc4f0e433588c3562efc322ddb"
       ],
       "layout": "IPY_MODEL_97d837fcb3944443970fb37175e2a8e2"
      }
     },
     "6e17b8245d8e4202ba9b9cc8b2a16282": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_69523589bb43433cb1de06bb116e8958",
       "max": 79,
       "style": "IPY_MODEL_9403280d739a4430938626d044d78ba6",
       "value": 79
      }
     },
     "6ea469953ee04cc0bab8260495313f27": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "6ee900a5b4a44c46afefd5dbc277089a": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "7007118bd6734014b09e7e49477c0c4a": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "728662f43b604263a57237acc830f2fa": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "73790d724d814ee2bb3442420f1efa87": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "743078593d73460abf57d9429fa92c8b": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "75a786ad7f7c459fa6f108319910e93c": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "7678625a957942feb9d1f19894db4207": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_f53c52d6047a41fb9fca83d28d5cf184",
       "style": "IPY_MODEL_31beb781955e4abea1fa33138a70ca69",
       "value": "100% 1875/1875 [01:16&lt;00:00, 24.50it/s]"
      }
     },
     "7722c8d516c14ea7941084c5844c6162": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_95ee7405b95748a691751e8e3674cee7",
       "style": "IPY_MODEL_d10f04a4086c4d0a82af922259aa8dfd",
       "value": "100% 1875/1875 [01:15&lt;00:00, 24.69it/s]"
      }
     },
     "7831b7d3b6c54b538bea298f096fb536": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "7ae4bac9876546f0b92d11366637d861": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "7b73e5216e7840faa3eb2662a0c01b2a": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "7ba08023e68346c99d295c8cf317c9c6": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "7bede11c51444425aabed22076a37eb1": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "7c36e77d1ef14e1bbcb55f3390ac0f97": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "7c5e8e8b8fec405dbf312c03d8bc8919": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "danger",
       "layout": "IPY_MODEL_75a786ad7f7c459fa6f108319910e93c",
       "max": 79,
       "style": "IPY_MODEL_529def52c63f4a29b578b1ece833bd93"
      }
     },
     "7d44bc9bc4754f09b48ba223a1311fb0": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_415f6e33dce140f3ac3935aa6769bce9",
       "max": 1875,
       "style": "IPY_MODEL_8f9d3d77e4e740388b5d7a15bd09e57f",
       "value": 1875
      }
     },
     "7ddab42457f440d59ede14a2b662c324": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_18983393e5414f6f858e8c034d9fb9bc",
       "style": "IPY_MODEL_73790d724d814ee2bb3442420f1efa87",
       "value": "100% 79/79 [00:01&lt;00:00, 47.10it/s]"
      }
     },
     "803ab43df57c4b49a0c9400d794ba6c3": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "81b4d8b1315742a183033513b0459455": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_54fef2fc76454dbc87f85d678e570fca",
       "style": "IPY_MODEL_afc4b1da754d42a7a5fb84526e3b139c",
       "value": "100% 1875/1875 [01:16&lt;00:00, 24.67it/s]"
      }
     },
     "8271d20551384a1f9c2c3fe6bfa37094": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "832391149cbb481e90d8fa94041f6b0e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_a0d36ce229f14b88976975c250e2ca3b",
       "max": 79,
       "style": "IPY_MODEL_1806bd82afb64886a188425bf61de3ee",
       "value": 79
      }
     },
     "8480fea52e3d488aae8d090a01f5688e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_910073d15395494f9ec3efa8bf691e9e",
       "style": "IPY_MODEL_fcc3a232659a4a2a83d065a787c26402",
       "value": "100% 79/79 [00:01&lt;00:00, 45.69it/s]"
      }
     },
     "8531fb26ddc84ed9ac53bbf1acdf5139": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_944251061ed24845865095e543468fb1",
       "style": "IPY_MODEL_8d807d05f54f4175a3ff616d65083830",
       "value": "100% 79/79 [00:01&lt;00:00, 46.50it/s]"
      }
     },
     "86c82f8a014e4048879c0d3298efdda2": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "86e7222efdff4a8392ad8ac37928440a": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "87334fc2d276402e917d0517382e8b42": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "8a28e7cf467f422bafcc4a4f7f80d99d": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "8cb88a4d7a6443aea442fc31f7bb3eee": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "8d807d05f54f4175a3ff616d65083830": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "8eb957262d244a90a6f5d70e3c1195d9": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "8f9d3d77e4e740388b5d7a15bd09e57f": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "90af5c7af9cc46e4b2e1dde7a04ce01a": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "910073d15395494f9ec3efa8bf691e9e": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "925b89ad27154b26ad4834342ea0e497": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_b0262c82ff9346258f175008712f31c4",
       "style": "IPY_MODEL_0df95b684157426880512578f9f78d1f",
       "value": "100% 79/79 [00:01&lt;00:00, 46.11it/s]"
      }
     },
     "93a9f35b76de44c0aaf776946347b76e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_baf09de2e4e645fa88b93a5a82fa1664",
        "IPY_MODEL_2d5c4ccf97f641f4b36c516e49b42f1c"
       ],
       "layout": "IPY_MODEL_3bdc65622ceb44fa895eeccf2ecc0152"
      }
     },
     "9403280d739a4430938626d044d78ba6": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "944251061ed24845865095e543468fb1": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "94a492c2964f4a7baf815dad7841b6db": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "9568817c86d54a68b638bb217be5696d": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "95ee7405b95748a691751e8e3674cee7": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "96a3541c181347b39e533bcedf25f0d4": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_9d9e984f53ef402ab5d1642ce9aaa392",
       "max": 79,
       "style": "IPY_MODEL_0e7868acdb7c4c619625d37deb9eb3a3",
       "value": 79
      }
     },
     "97d837fcb3944443970fb37175e2a8e2": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "9b94b49d7ef44cba8d0b514159407f8d": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "9ba772bad6514a60b6a67b5801836a54": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_01447e49744f4e15a918eab38023ee12",
       "max": 79,
       "style": "IPY_MODEL_69a43112476f4730b7c8d60e2b301e10",
       "value": 79
      }
     },
     "9bb7628377ca49989eeeddf2c4d5f4c8": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "9d9e984f53ef402ab5d1642ce9aaa392": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "9f4929ea841e4ab3960b47003ad1d648": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "a0d36ce229f14b88976975c250e2ca3b": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "a167fa6ae8a64dd0a6a5e76074c13c70": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_05e09cca05514276a8729c3d97e57437",
       "style": "IPY_MODEL_3538802d34dc4533b03785643f57235d",
       "value": "100% 1875/1875 [01:17&lt;00:00, 24.13it/s]"
      }
     },
     "a18ce0988e894226b3ae3fec703d1114": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_44766dc4803a4a5aa6a8c752d0ae2049",
        "IPY_MODEL_8480fea52e3d488aae8d090a01f5688e"
       ],
       "layout": "IPY_MODEL_19c7b3b12ec3416299db573bd6984df5"
      }
     },
     "a2b7b5f4fe3f437c810ee41d2e75ace7": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "a9d79f7b477d47e2a2d6f487f6736c97": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_7c36e77d1ef14e1bbcb55f3390ac0f97",
       "style": "IPY_MODEL_1b92e09380d34936a89235d115570832",
       "value": "100% 79/79 [00:01&lt;00:00, 46.85it/s]"
      }
     },
     "aeea7f2d543645a1a4650e63048c5e21": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_3022b9630fe343639f36b7d80ef877eb",
        "IPY_MODEL_a9d79f7b477d47e2a2d6f487f6736c97"
       ],
       "layout": "IPY_MODEL_b2cdb88885a54889a24905348828163c"
      }
     },
     "afa037fad5894728ac3f93bf6cfe91e7": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "afc4b1da754d42a7a5fb84526e3b139c": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "b0262c82ff9346258f175008712f31c4": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "b0eb50e2e02740d4a166cb4ced86c574": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_2a521e703d1a4df598d4d7ebc8917df1",
       "max": 79,
       "style": "IPY_MODEL_cb2e7e4d6dcd4644bf6bcc6379a2eebf",
       "value": 79
      }
     },
     "b2cdb88885a54889a24905348828163c": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "b4e3c1c1e8f64d7d8158e341efda2edb": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_5300061d70b04209a9b59fdd6493424c",
       "style": "IPY_MODEL_0ea17b3b30934876a6770551876e249a",
       "value": "100% 79/79 [00:01&lt;00:00, 46.76it/s]"
      }
     },
     "b7b7fe61a5094687b3b471fd96a9c3f0": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_3e54bef0104e4979b134823ac90c4fa6",
       "style": "IPY_MODEL_6ea469953ee04cc0bab8260495313f27",
       "value": "100% 1875/1875 [01:17&lt;00:00, 24.34it/s]"
      }
     },
     "b94d8094b6e24f9cb5f82e15b6b7a8df": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_3c18e69d054d4a259f9b75cb5f2f57f5",
       "style": "IPY_MODEL_e5b247044ffc47b5aeca6bfe487a7f28",
       "value": "100% 79/79 [00:01&lt;00:00, 46.40it/s]"
      }
     },
     "ba07dd1810c74b6ba518f6221168913b": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "baf09de2e4e645fa88b93a5a82fa1664": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_dfea0c5e412b4a98827d3c14e7cd07f7",
       "max": 1875,
       "style": "IPY_MODEL_8cb88a4d7a6443aea442fc31f7bb3eee",
       "value": 1875
      }
     },
     "bc63d0e73204431c8d36c5cc5851b74c": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "beb4598548a6448e8f7a80bb8683bb05": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "bf6b4ccf1cbc413b8fb85fa5b7fc36cd": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_cd8c8dc54ce7434ea8d1691d720fe360",
        "IPY_MODEL_b7b7fe61a5094687b3b471fd96a9c3f0"
       ],
       "layout": "IPY_MODEL_ee020566515641e1b6ddb5a047e1be9a"
      }
     },
     "c27ea7e05da146a2ae2b4713ebcd38c0": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "c292aa2c40e347418d62e2d8befd013d": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "c979ea3c350142b6b8105842c836e5f8": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_6e17b8245d8e4202ba9b9cc8b2a16282",
        "IPY_MODEL_290d0742411a4702ae281fea70f3bcaa"
       ],
       "layout": "IPY_MODEL_5a6b2b387a7141bfbe7d6826f12f54eb"
      }
     },
     "cb26e37966c94d6098dd962de92a45a0": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_d833f8a0581c433697eb420645e80c6e",
       "max": 79,
       "style": "IPY_MODEL_e956690c18f34ad6bc3fdf3b6fe9cf37",
       "value": 79
      }
     },
     "cb2e7e4d6dcd4644bf6bcc6379a2eebf": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "cb99dce53a4a4288a52f0b4b831ad906": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_fb46e5e3660f468e8a4d934bd7672907",
       "max": 79,
       "style": "IPY_MODEL_6262c5576e884ab7886d6e186df54ae8",
       "value": 79
      }
     },
     "cc42d02ea372471d9a4c420d65f26b4a": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_23b650880517487192d3a539ce2eee33",
        "IPY_MODEL_a167fa6ae8a64dd0a6a5e76074c13c70"
       ],
       "layout": "IPY_MODEL_1853c04636984ec6a178826a85930c1b"
      }
     },
     "cd8c8dc54ce7434ea8d1691d720fe360": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_eea954d8b59a435abd4d0a41f5059c4e",
       "max": 1875,
       "style": "IPY_MODEL_675221f89e2140ef8d60204390e9acda",
       "value": 1875
      }
     },
     "d10f04a4086c4d0a82af922259aa8dfd": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "d3f17e88798a4f6a9e33e9005df7c7f9": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "d802fe45395847869f7925a7dc7394e2": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "d833f8a0581c433697eb420645e80c6e": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "da9d36f1abb3470f89dcdd21931dcfcd": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_86e7222efdff4a8392ad8ac37928440a",
       "max": 1875,
       "style": "IPY_MODEL_5a9947e7eab6439b880b24ee85c27c61",
       "value": 1875
      }
     },
     "dc1ee5808f0946eaa79fd898647e0ba2": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "dcf78f860cc4483d9741b04aa3d6d1db": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "de307ce4268d41fb918d154ea0409dbd": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_cb26e37966c94d6098dd962de92a45a0",
        "IPY_MODEL_925b89ad27154b26ad4834342ea0e497"
       ],
       "layout": "IPY_MODEL_fe91d8419f6b456f9e66d219f5991369"
      }
     },
     "dfea0c5e412b4a98827d3c14e7cd07f7": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "e139b8cfb0304b9dbfe35e0418dcaa07": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_9ba772bad6514a60b6a67b5801836a54",
        "IPY_MODEL_1e92b05840124fe7bf8d9bd143da48cd"
       ],
       "layout": "IPY_MODEL_afa037fad5894728ac3f93bf6cfe91e7"
      }
     },
     "e199254644b34a6395dcc41cee2bd738": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "e1cc23676b87406fbd3840b661dfb5e7": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_cb99dce53a4a4288a52f0b4b831ad906",
        "IPY_MODEL_8531fb26ddc84ed9ac53bbf1acdf5139"
       ],
       "layout": "IPY_MODEL_6027f1f6d186491e9b1c6a0b4488f602"
      }
     },
     "e42169a0d360481ebd3c529cc421c2a5": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_832391149cbb481e90d8fa94041f6b0e",
        "IPY_MODEL_3dcea128ee1c47fabd9c1ca1513cbe0e"
       ],
       "layout": "IPY_MODEL_f65aa846efdf411bb970b52d69e6e3a5"
      }
     },
     "e5b247044ffc47b5aeca6bfe487a7f28": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "e6b679d3e9d544cfb613164b03c8991d": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "e70e510e7f5844b0a1d578f7679343e0": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_4a9b90eef6b94bcc9e7919fbba4419fb",
        "IPY_MODEL_b94d8094b6e24f9cb5f82e15b6b7a8df"
       ],
       "layout": "IPY_MODEL_beb4598548a6448e8f7a80bb8683bb05"
      }
     },
     "e83e862474bd4899ad386c004d8cb9a2": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_7c5e8e8b8fec405dbf312c03d8bc8919",
        "IPY_MODEL_0b6e4731c92f432689b4cc7c385392d4"
       ],
       "layout": "IPY_MODEL_a2b7b5f4fe3f437c810ee41d2e75ace7"
      }
     },
     "e921d08f6e2941b4a40a1cdf3167b7c2": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "e956690c18f34ad6bc3fdf3b6fe9cf37": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "ea36d1a94b36430f947fcf0962fabc40": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_33497dd7f39b42da8ce8c28ae0eb1ebf",
        "IPY_MODEL_32b30cb9acc244a9914e34f45879c21e"
       ],
       "layout": "IPY_MODEL_6cfd09d8ff154c2786daa740299f3cf2"
      }
     },
     "eb836fe776da47d8b89c0b01ecb89d98": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_d802fe45395847869f7925a7dc7394e2",
       "max": 1875,
       "style": "IPY_MODEL_34b6a9403d5b42e695cadb3a3446004a",
       "value": 1875
      }
     },
     "ee020566515641e1b6ddb5a047e1be9a": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "eea954d8b59a435abd4d0a41f5059c4e": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "f209fbc7469f40e7963d4c18bae6e111": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "f53c52d6047a41fb9fca83d28d5cf184": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "f612da8636d642e6b0d42ea2dfab928a": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_3f6fa3cbad674dcb93ae3b17bcfb60d6",
        "IPY_MODEL_7678625a957942feb9d1f19894db4207"
       ],
       "layout": "IPY_MODEL_4bbe6a4019384880accf89bd6ea4ce43"
      }
     },
     "f65aa846efdf411bb970b52d69e6e3a5": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "f8a3eabaf27c46ca80f37df959abdbb4": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "IntProgressModel",
      "state": {
       "bar_style": "success",
       "layout": "IPY_MODEL_e921d08f6e2941b4a40a1cdf3167b7c2",
       "max": 79,
       "style": "IPY_MODEL_065445947c66425282ee8aaf7954a120",
       "value": 79
      }
     },
     "fb46e5e3660f468e8a4d934bd7672907": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "fcc3a232659a4a2a83d065a787c26402": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "description_width": ""
      }
     },
     "fe681ff909f441eea634e7bb6169422b": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "fe91d8419f6b456f9e66d219f5991369": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     },
     "ff9bf7cbddb6483a821d35aa872e4417": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HBoxModel",
      "state": {
       "children": [
        "IPY_MODEL_7d44bc9bc4754f09b48ba223a1311fb0",
        "IPY_MODEL_81b4d8b1315742a183033513b0459455"
       ],
       "layout": "IPY_MODEL_8271d20551384a1f9c2c3fe6bfa37094"
      }
     },
     "fffa785bfcd54cb49a1c2a420bac2b51": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.1.0",
      "model_name": "HTMLModel",
      "state": {
       "layout": "IPY_MODEL_44f2ee887e874370bc4b4d0330c843af",
       "style": "IPY_MODEL_4c1560d1809d4a1ea741d43dcef95076",
       "value": "100% 79/79 [00:01&lt;00:00, 44.05it/s]"
      }
     },
     "fffd484e385e4249bb9759808ee02fac": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.0.0",
      "model_name": "LayoutModel",
      "state": {}
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
